File size: 5,093 Bytes
9cb4953
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Deepfake Explainer Inference Helper (Multi-Image Version)
import torch
import os
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from peft import PeftModel
import matplotlib.pyplot as plt

# Function to analyze images for deepfakes
def _load_model_and_processor():
    """Load the base vision-language model, its processor, and the LoRA adapter.

    Returns:
        (model, processor) tuple; the model is device-mapped automatically
        and wrapped with the deepfake-explainer PEFT adapter.
    """
    base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
    adapter_id = "saakshigupta/deepfake-explainer-2"

    processor = AutoProcessor.from_pretrained(base_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # Attach the fine-tuned LoRA weights on top of the frozen base model.
    model = PeftModel.from_pretrained(model, adapter_id)
    return model, processor


def _collect_images(original_image_path, cam_image_path, cam_overlay_path, comparison_image_path):
    """Open the original image plus any optional auxiliary images.

    Returns:
        (images, titles): parallel lists of PIL RGB images and display titles.
        The original image is always first; optional images are appended only
        when a path was provided.
    """
    images = [Image.open(original_image_path).convert("RGB")]
    titles = ["Original Image"]

    optional = [
        (cam_image_path, "CAM Image"),
        (cam_overlay_path, "CAM Overlay"),
        (comparison_image_path, "Comparison Image"),
    ]
    for path, title in optional:
        if path:
            images.append(Image.open(path).convert("RGB"))
            titles.append(title)
    return images, titles


def _display_images(images, titles):
    """Show the images in a 2-column matplotlib grid with their titles."""
    rows = (len(images) + 1) // 2
    cols = min(2, len(images))
    _, axs = plt.subplots(rows, cols, figsize=(12, 6 * rows))

    # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray
    # otherwise; normalize to a flat list so one loop handles every case.
    axes = [axs] if len(images) == 1 else list(axs.flatten())

    for ax, (img, title) in zip(axes, zip(images, titles)):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    # Hide leftover grid cells (e.g. 3 images on a 2x2 grid) so no empty
    # labeled frame is drawn.
    for ax in axes[len(images):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def analyze_deepfake(original_image_path, cam_image_path=None, cam_overlay_path=None, comparison_image_path=None, custom_prompt=None):
    """Analyze multiple images for signs of deepfakes with detailed explanation.

    Args:
        original_image_path: Path to the image under analysis (required).
        cam_image_path: Optional path to a CAM heatmap image.
        cam_overlay_path: Optional path to a CAM overlay image.
        comparison_image_path: Optional path to a side-by-side comparison image.
        custom_prompt: Optional prompt text; a default analysis prompt is used
            when None.

    Returns:
        The model's generated analysis text with the prompt stripped.
    """
    model, processor = _load_model_and_processor()

    images, image_titles = _collect_images(
        original_image_path, cam_image_path, cam_overlay_path, comparison_image_path
    )
    original_image = images[0]

    _display_images(images, image_titles)

    # Build the prompt (default asks for both technical and lay explanations).
    if custom_prompt is None:
        prompt = "Analyze these images carefully and determine if there's a deepfake. Provide both a technical explanation and a simple explanation anyone can understand."
    else:
        prompt = custom_prompt

    # Process the image(s) and text; some processor versions reject multiple
    # images, so fall back to the original image alone on failure.
    try:
        inputs = processor(text=prompt, images=images, return_tensors="pt")
    except Exception as e:
        print(f"Warning: Unable to process multiple images ({e}). Falling back to original image only.")
        inputs = processor(text=prompt, images=original_image, return_tensors="pt")

    # Repair a degenerate (zero-sized) cross-attention mask, which the
    # processor sometimes emits; rebuild it with the visual-feature dimension
    # used during training.
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404  # Critical dimension from training
        new_mask = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device if torch.is_tensor(inputs['cross_attention_mask']) else None,
        )
        inputs['cross_attention_mask'] = new_mask
        print("Fixed cross-attention mask dimensions")

    # Move tensor inputs to the model's device (non-tensors are dropped).
    inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

    # Generate the analysis. do_sample=True is required for temperature/top_p
    # to take effect; without it, generate() decodes greedily and silently
    # ignores both sampling parameters.
    print("Generating analysis...")
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode and strip the echoed prompt so only the model's answer remains.
    response = processor.decode(output_ids[0], skip_special_tokens=True)
    if prompt in response:
        result = response.split(prompt)[-1].strip()
    else:
        result = response

    return result

# Example usage
if __name__ == "__main__":
    def _optional_path(message):
        """Prompt for a path; a blank/whitespace-only answer means 'skip' (None).

        A non-blank answer is returned exactly as typed (not stripped), matching
        what analyze_deepfake receives for the required original image path.
        """
        path = input(message)
        return None if path.strip() == "" else path

    # Get image paths from user; only the original image is required.
    original_image_path = input("Enter path to original image: ")
    cam_image_path = _optional_path("Enter path to CAM image (or press Enter to skip): ")
    cam_overlay_path = _optional_path("Enter path to CAM overlay image (or press Enter to skip): ")
    comparison_image_path = _optional_path("Enter path to comparison image (or press Enter to skip): ")

    # Analyze images
    analysis = analyze_deepfake(original_image_path, cam_image_path, cam_overlay_path, comparison_image_path)

    # Print result. NOTE: the original had raw newlines inside a single-quoted
    # string literal, which is a syntax error; use \n escapes instead.
    print("\n===== DEEPFAKE ANALYSIS RESULT =====\n")
    print(analysis)