File size: 7,403 Bytes
fc8df74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e9829e
 
 
 
 
 
 
fc8df74
8e9829e
 
fc8df74
8e9829e
 
 
 
 
 
fc8df74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df6f6d2
fc8df74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e9829e
fc8df74
 
 
 
 
8e9829e
fc8df74
 
 
 
 
40598e4
 
 
 
 
 
 
61793df
40598e4
61793df
5f9c46d
f9e90ec
40598e4
 
 
 
 
 
df6f6d2
40598e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from PIL import Image, ImageDraw
from scipy.signal import convolve2d
np.random.seed(123)

# --- Configuration ---
BOX_SIZE = 128

# --- Helper Functions (GetMap, getFourier) ---
# These computational functions remain the same as before.
def GetMap(image_channel):
    f = image_channel.astype(np.float64)
    alpha = np.random.rand(4)
    sigma = 0.005
    delta = 10
    count = 0
    max_iter = 30
    tolerance = 0.01
    while True:
        count += 1
        kernel = np.array([[0, alpha[0], 0], [alpha[1], 0, alpha[2]], [0, alpha[3], 0]])
        filtered = convolve2d(f, kernel, mode='same', boundary='symm')
        r = np.abs(f - filtered) / 255.0
        r = r[1:-1, 1:-1]
        e = np.exp(-r**2 / sigma)
        w = e / (e + 1 / delta)
        if np.sum(w) < 1e-6:
            print("Warning: Sum of weights is near zero. Exiting GetMap early.")
            return e.ravel()
        w_flat = w.ravel()
        W = np.diag(w_flat)
        value1 = f[:-2, 1:-1].ravel(); value2 = f[1:-1, :-2].ravel(); value3 = f[1:-1, 2:].ravel(); value4 = f[2:, 1:-1].ravel()
        Y = f[1:-1, 1:-1].ravel()
        X = np.column_stack((value1, value2, value3, value4))
        try:
            alpha_new = np.linalg.inv(X.T @ W @ X) @ (X.T @ W @ Y)
        except np.linalg.LinAlgError:
            print("Warning: Singular matrix encountered. Cannot compute inverse.")
            return e.ravel()
        if np.linalg.norm(alpha - alpha_new) < tolerance or count > max_iter: break
        alpha = alpha_new
        sigma = np.sum(w * (r**2)) / np.sum(w)
    return e.ravel()

def getFourier(prob):
    #imFft = np.fft.fftshift(np.fft.fft2(prob))
    #imFft = np.abs(imFft)
    #if np.max(imFft) > 0:
    #    imFft = (imFft / np.max(imFft) * 255)
    #imFft = imFft.astype(np.uint8)
    #imFft = (imFft > (0.5 * 255)).astype(np.uint8)
    # Compute the Fourier Transform and shift zero frequency to the center
    imFft = np.fft.fftshift(np.fft.fft2(prob))
    
    # Take the magnitude (absolute value)
    imFft = np.abs(imFft)
    
    # Convert to 8-bit unsigned integer (similar to uint8 in MATLAB)
    imFft = np.uint8(imFft)/255

    # Binarize the image with a threshold of 0.5
    imFft = (imFft > 0.5).astype(np.uint8)
    return imFft

# --- New Gradio Interaction Functions ---

def draw_box_on_image(image: np.ndarray, box_coords: tuple, color="red", width=3) -> np.ndarray:
    """Draws a bounding box on a NumPy image array."""
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    x, y = box_coords
    rectangle = (x, y, x + BOX_SIZE, y + BOX_SIZE)
    draw.rectangle(rectangle, outline=color, width=width)
    return np.array(pil_image)

def on_upload_image(image: np.ndarray) -> tuple:
    """Called when an image is first uploaded. Stores the original image and draws the initial box."""
    initial_coords = (0, 0)
    image_with_box = draw_box_on_image(image, initial_coords)
    # Returns: (image_with_box for display, original_image for state, initial_coords for state)
    return image_with_box, image, initial_coords

def move_selection_box(evt: gr.SelectData, original_image: np.ndarray) -> tuple:
    """Called when the user clicks the image. It moves the box to the clicked location."""
    # Center the box on the user's click
    x = evt.index[0] - BOX_SIZE // 2
    y = evt.index[1] - BOX_SIZE // 2
    
    # Clamp coordinates to ensure the box stays within the image boundaries
    img_h, img_w, _ = original_image.shape
    x = max(0, min(x, img_w - BOX_SIZE))
    y = max(0, min(y, img_h - BOX_SIZE))
    
    new_coords = (int(x), int(y))
    image_with_box = draw_box_on_image(original_image, new_coords)
    # Returns: (image_with_box for display, new_coords for state)
    return image_with_box, new_coords

def analyze_region(original_image: np.ndarray, box_coords: tuple):
    """The main analysis function, triggered by the 'Analyze' button."""
    if original_image is None:
        gr.Warning("Please upload an image first!")
        return None
        
    print(f"\n--- Analysis Started for region at {box_coords} ---")
    start_time = time.time()
    
    x, y = box_coords
    patch = original_image[y:y + BOX_SIZE, x:x + BOX_SIZE]
    print(f"1. Patch extracted with shape: {patch.shape}")

    if len(patch.shape) == 3: analysis_channel = patch[:, :, 1] # Green channel
    else: analysis_channel = patch # Grayscale
    
    print("2. Computing probability map...")
    prob_flat = GetMap(analysis_channel)
    prob_map_shape = (analysis_channel.shape[0] - 2, analysis_channel.shape[1] - 2)
    prob_map = prob_flat.reshape(prob_map_shape)
    
    print("3. Computing Fourier transform...")
    fft_result = getFourier(prob_map)

    # Plotting
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(patch); axs[0].set_title("Selected 128x128 Patch"); axs[0].axis("off")
    axs[1].imshow(prob_map, cmap='gray'); axs[1].set_title("Probability Map"); axs[1].axis("off")
    axs[2].imshow(np.abs(fft_result), cmap='gray'); axs[2].set_title("Fourier Transform"); axs[2].axis("off")
    plt.tight_layout()
    
    print(f"4. Analysis complete in {time.time() - start_time:.2f} seconds.")
    return fig

def build_demo():
    # --- Build the Gradio Interface using Blocks ---
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # State variables store data (like the original image) between user interactions
        original_image_state = gr.State()
        box_coords_state = gr.State(value=(0, 0))

        gr.Markdown("# 🎨 Color Filter Array Analysis")
        gr.Markdown(
            "Analyzes artifacts introduced during the camera's raw image processing. Inconsistencies in the **Color Filter Array (CFA)** interpolation pattern can reveal areas that have been spliced from another image or copy-pasted within the same image (copy-move).\n"
            "\n"
            "## Instructions:\n"
            "1. **Upload** an image.\n"
            "2. **Click** anywhere on the image to move the 128x128 selection box.\n"
            "3. Press the **Analyze Region** button to start processing."
        )

        with gr.Row():
            image_display = gr.Image(type="numpy", label="Selection Canvas", interactive=True)
            output_plot = gr.Plot(label="Analysis Results")

        analyze_button = gr.Button("Analyze Region", variant="primary")

        # --- Wire up the event listeners ---

        # 1. When a new image is uploaded, call on_upload_image
        image_display.upload(
            fn=on_upload_image,
            inputs=[image_display],
            outputs=[image_display, original_image_state, box_coords_state]
        )

        # 2. When the user clicks the image, call move_selection_box
        image_display.select(
            fn=move_selection_box,
            inputs=[original_image_state],
            outputs=[image_display, box_coords_state]
        )
        
        # 3. When the user clicks the analyze button, call analyze_region
        analyze_button.click(
            fn=analyze_region,
            inputs=[original_image_state, box_coords_state],
            outputs=[output_plot],
            # Show a progress bar during analysis
            show_progress="full" 
        )
    return demo

if __name__ == "__main__":
    app = build_demo()
    app.launch()