# Deepfake Explainer Inference Helper (Multi-Image Version)

import torch
import os
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from peft import PeftModel
import matplotlib.pyplot as plt


def _display_images(images, image_titles):
    """Show the loaded images in a 2-column grid for visual inspection.

    Args:
        images: list of PIL.Image objects (at least one).
        image_titles: parallel list of subplot titles.
    """
    rows = (len(images) + 1) // 2
    cols = min(2, len(images))
    fig, axs = plt.subplots(rows, cols, figsize=(12, 6 * rows))

    # Normalize `axs` to a flat list regardless of grid shape
    # (plt.subplots returns a bare Axes for 1x1, a 1-D array for one row,
    # and a 2-D array otherwise).
    if len(images) == 1:
        axes = [axs]
    elif len(images) > 2:
        axes = list(axs.flatten())
    else:
        axes = list(axs)

    for ax, img, title in zip(axes, images, image_titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    # Bug fix: with an odd image count the original left the unused
    # subplot visible as an empty frame — hide any leftover axes.
    for ax in axes[len(images):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def analyze_deepfake(original_image_path, cam_image_path=None,
                     cam_overlay_path=None, comparison_image_path=None,
                     custom_prompt=None):
    """Analyze multiple images for signs of deepfakes with detailed explanation.

    Loads the base vision-instruct model plus a fine-tuned LoRA adapter,
    displays the supplied images, and generates a textual analysis.

    Args:
        original_image_path: path to the image under inspection (required).
        cam_image_path: optional path to a class-activation-map image.
        cam_overlay_path: optional path to a CAM overlay image.
        comparison_image_path: optional path to a comparison image.
        custom_prompt: optional prompt overriding the default instruction.

    Returns:
        The model's generated analysis as a string (prompt text stripped
        when it appears in the decoded output).

    NOTE(review): the model and processor are re-loaded on every call —
    expensive; consider caching them at module level if called repeatedly.
    """
    # Set up model paths
    base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
    adapter_id = "saakshigupta/deepfake-explainer-2"

    # Load processor and base model (fp16, auto device placement)
    processor = AutoProcessor.from_pretrained(base_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    # Attach the fine-tuned deepfake-explainer adapter
    model = PeftModel.from_pretrained(model, adapter_id)

    # Load the required image, then any optional ones (data-driven to
    # avoid the original's three copy-pasted if-blocks)
    original_image = Image.open(original_image_path).convert("RGB")
    images = [original_image]
    image_titles = ["Original Image"]
    optional_inputs = (
        (cam_image_path, "CAM Image"),
        (cam_overlay_path, "CAM Overlay"),
        (comparison_image_path, "Comparison Image"),
    )
    for path, title in optional_inputs:
        if path:
            images.append(Image.open(path).convert("RGB"))
            image_titles.append(title)

    # Display the images
    _display_images(images, image_titles)

    # Create prompt
    if custom_prompt is None:
        prompt = "Analyze these images carefully and determine if there's a deepfake. Provide both a technical explanation and a simple explanation anyone can understand."
    else:
        prompt = custom_prompt

    # Process the images and text; fall back to the original image alone
    # if the processor cannot handle multiple images.
    try:
        inputs = processor(text=prompt, images=images, return_tensors="pt")
    except Exception as e:
        print(f"Warning: Unable to process multiple images ({e}). Falling back to original image only.")
        inputs = processor(text=prompt, images=original_image, return_tensors="pt")

    # Fix cross-attention mask: if any dimension collapsed to 0, rebuild a
    # ones-mask with the visual-feature dimension used during training.
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404  # Critical dimension from training
        new_mask = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device
            if torch.is_tensor(inputs['cross_attention_mask']) else None,
        )
        inputs['cross_attention_mask'] = new_mask
        print("Fixed cross-attention mask dimensions")

    # Move tensors to the model's device (drop any non-tensor entries)
    inputs = {k: v.to(model.device) for k, v in inputs.items()
              if isinstance(v, torch.Tensor)}

    # Generate output
    print("Generating analysis...")
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            # Bug fix: without do_sample=True, transformers ignores
            # temperature/top_p (greedy decoding) and emits a warning.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode output
    response = processor.decode(output_ids[0], skip_special_tokens=True)

    # Extract the model's response (removing the prompt when echoed back)
    if prompt in response:
        result = response.split(prompt)[-1].strip()
    else:
        result = response

    return result


# Example usage
if __name__ == "__main__":
    # Get image paths from the user; blank input skips an optional image
    original_image_path = input("Enter path to original image: ")

    cam_image_path = input("Enter path to CAM image (or press Enter to skip): ")
    if cam_image_path.strip() == "":
        cam_image_path = None

    cam_overlay_path = input("Enter path to CAM overlay image (or press Enter to skip): ")
    if cam_overlay_path.strip() == "":
        cam_overlay_path = None

    comparison_image_path = input("Enter path to comparison image (or press Enter to skip): ")
    if comparison_image_path.strip() == "":
        comparison_image_path = None

    # Analyze images
    analysis = analyze_deepfake(original_image_path, cam_image_path,
                                cam_overlay_path, comparison_image_path)

    # Print result
    print(" ===== DEEPFAKE ANALYSIS RESULT ===== ")
    print(analysis)