"""Gradio demo that looks for Color Filter Array (CFA) interpolation traces.

The user uploads an image, positions a fixed-size square selection box by
clicking, and triggers an analysis of that region.  The analysis fits a
4-neighbour linear predictor to the region (``GetMap``), turns the resulting
per-pixel likelihood map into a binarized Fourier magnitude (``getFourier``)
and plots the patch, the map and the spectrum side by side.
"""

import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw
from scipy.signal import convolve2d

# Seed the global NumPy RNG once at import time.  GetMap draws its initial
# coefficients from this global stream, so successive calls see different
# starting points.
np.random.seed(123)

# --- Configuration ---
# Side length, in pixels, of the square selection box / analyzed patch.
BOX_SIZE = 128


# --- Helper Functions (GetMap, getFourier) ---
def GetMap(image_channel):
    """Fit a 4-neighbour linear predictor and return a per-pixel likelihood map.

    The routine alternates between two steps until the coefficients converge:

    1. Predict every pixel from its up/left/right/down neighbours using the
       current coefficients ``alpha`` and turn the prediction residual into a
       Gaussian likelihood ``e`` and a weight ``w``.
    2. Re-estimate ``alpha`` by weighted least squares using ``w``, and update
       the residual variance ``sigma``.

    Args:
        image_channel: 2-D array holding a single image channel.  It is
            converted to ``float64`` internally; residuals are scaled by 255,
            so 8-bit intensities are presumably expected.

    Returns:
        1-D ``float64`` array of length ``(H - 2) * (W - 2)`` containing the
        likelihood ``e`` of each interior pixel (border pixels are dropped),
        in row-major order.
    """
    f = image_channel.astype(np.float64)
    alpha = np.random.rand(4)
    sigma = 0.005
    delta = 10
    count = 0
    max_iter = 30
    tolerance = 0.01

    # The regression design never changes between iterations, so build it
    # once.  Column order is (up, left, right, down) and matches alpha[0..3].
    Y = f[1:-1, 1:-1].ravel()
    X = np.column_stack((
        f[:-2, 1:-1].ravel(),   # up
        f[1:-1, :-2].ravel(),   # left
        f[1:-1, 2:].ravel(),    # right
        f[2:, 1:-1].ravel(),    # down
    ))

    while True:
        count += 1

        kernel = np.array([
            [0, alpha[0], 0],
            [alpha[1], 0, alpha[2]],
            [0, alpha[3], 0],
        ])
        # convolve2d flips its kernel.  Flip it back so that alpha[0] really
        # multiplies the "up" neighbour, alpha[1] the "left" one, etc. --
        # the same pairing the regression below uses.
        filtered = convolve2d(f, kernel[::-1, ::-1], mode='same', boundary='symm')

        r = np.abs(f - filtered) / 255.0
        r = r[1:-1, 1:-1]

        e = np.exp(-r**2 / sigma)
        w = e / (e + 1 / delta)

        if np.sum(w) < 1e-6:
            print("Warning: Sum of weights is near zero. Exiting GetMap early.")
            return e.ravel()

        # Weighted normal equations (X^T W X) alpha = X^T W Y.  Scaling the
        # rows of X by w is equivalent to multiplying by diag(w) without ever
        # materialising that N x N matrix.
        weighted_xt = X.T * w.ravel()
        try:
            alpha_new = np.linalg.solve(weighted_xt @ X, weighted_xt @ Y)
        except np.linalg.LinAlgError:
            print("Warning: Singular matrix encountered. Cannot compute inverse.")
            return e.ravel()

        if np.linalg.norm(alpha - alpha_new) < tolerance or count > max_iter:
            break

        alpha = alpha_new
        sigma = np.sum(w * (r**2)) / np.sum(w)

    return e.ravel()


def getFourier(prob):
    """Return a binarized, centred Fourier magnitude of ``prob``.

    Args:
        prob: 2-D array (the likelihood map produced by ``GetMap``).

    Returns:
        ``uint8`` array of the same shape containing 1 where the rounded,
        saturated magnitude exceeds half of the 8-bit range and 0 elsewhere.
    """
    # Compute the Fourier Transform and shift zero frequency to the center
    magnitude = np.abs(np.fft.fftshift(np.fft.fft2(prob)))

    # Emulate MATLAB's uint8(): round to nearest and saturate at 255.
    # A plain NumPy cast would wrap around for magnitudes above 255.
    saturated = np.clip(np.floor(magnitude + 0.5), 0, 255).astype(np.uint8)
    scaled = saturated / 255

    # Binarize the image with a threshold of 0.5
    return (scaled > 0.5).astype(np.uint8)


# --- New Gradio Interaction Functions ---
def draw_box_on_image(image: np.ndarray, box_coords: tuple, color="red", width=3) -> np.ndarray:
    """Return a copy of ``image`` with the selection box outlined on it.

    Args:
        image: Image as a NumPy array accepted by ``PIL.Image.fromarray``.
        box_coords: ``(x, y)`` of the box's top-left corner.
        color: Outline colour passed to ``ImageDraw.rectangle``.
        width: Outline width in pixels.

    Returns:
        New NumPy array with a ``BOX_SIZE`` x ``BOX_SIZE`` rectangle drawn.
    """
    canvas = Image.fromarray(image)
    left, top = box_coords
    ImageDraw.Draw(canvas).rectangle(
        (left, top, left + BOX_SIZE, top + BOX_SIZE),
        outline=color,
        width=width,
    )
    return np.array(canvas)


def on_upload_image(image: np.ndarray) -> tuple:
    """Handle a fresh upload: remember the image and draw the initial box.

    Args:
        image: The uploaded image as a NumPy array.

    Returns:
        ``(image_with_box, original_image, initial_coords)`` -- the picture to
        display, the untouched image to keep in state, and the box position
        ``(0, 0)`` to keep in state.
    """
    initial_coords = (0, 0)
    return draw_box_on_image(image, initial_coords), image, initial_coords


def move_selection_box(evt: gr.SelectData, original_image: np.ndarray) -> tuple:
    """Move the selection box so it is centred on the clicked point.

    Args:
        evt: Gradio select event; ``evt.index`` holds the clicked ``(x, y)``.
        original_image: The stored, box-free image (2-D or 3-D array).

    Returns:
        ``(image_with_box, new_coords)`` for the display and the state.  When
        no original image has been stored yet, both outputs are left
        unchanged.
    """
    if original_image is None:
        # A click can arrive before the upload handler stored an image.
        gr.Warning("Please upload an image first!")
        return gr.update(), gr.update()

    # Center the box on the user's click
    x = evt.index[0] - BOX_SIZE // 2
    y = evt.index[1] - BOX_SIZE // 2

    # Clamp coordinates to ensure the box stays within the image boundaries.
    # shape[:2] works for both grayscale (H, W) and colour (H, W, C) arrays.
    img_h, img_w = original_image.shape[:2]
    x = max(0, min(x, img_w - BOX_SIZE))
    y = max(0, min(y, img_h - BOX_SIZE))

    new_coords = (int(x), int(y))
    return draw_box_on_image(original_image, new_coords), new_coords


def analyze_region(original_image: np.ndarray, box_coords: tuple):
    """Run the CFA analysis on the selected patch and plot the results.

    Args:
        original_image: The stored, box-free image, or ``None`` if nothing
            has been uploaded.
        box_coords: ``(x, y)`` of the selection box's top-left corner.

    Returns:
        A Matplotlib figure with three panels (patch, probability map,
        binarized Fourier magnitude), or ``None`` when there is no image or
        the selected region is too small to analyze.
    """
    if original_image is None:
        gr.Warning("Please upload an image first!")
        return None

    print(f"\n--- Analysis Started for region at {box_coords} ---")
    start_time = time.time()

    x, y = box_coords
    patch = original_image[y:y + BOX_SIZE, x:x + BOX_SIZE]
    print(f"1. Patch extracted with shape: {patch.shape}")

    if len(patch.shape) == 3:
        analysis_channel = patch[:, :, 1]  # Green channel
    else:
        analysis_channel = patch  # Grayscale

    # GetMap drops a one-pixel border, so anything under 3x3 leaves no
    # interior pixels to analyze.
    if min(analysis_channel.shape) < 3:
        gr.Warning("Selected region is too small to analyze.")
        return None

    print("2. Computing probability map...")
    prob_flat = GetMap(analysis_channel)
    prob_map_shape = (analysis_channel.shape[0] - 2, analysis_channel.shape[1] - 2)
    prob_map = prob_flat.reshape(prob_map_shape)

    print("3. Computing Fourier transform...")
    fft_result = getFourier(prob_map)

    # Plotting
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        (patch, None, "Selected 128x128 Patch"),
        (prob_map, 'gray', "Probability Map"),
        (np.abs(fft_result), 'gray', "Fourier Transform"),
    )
    for ax, (data, cmap, title) in zip(axs, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()

    # Unregister the figure from pyplot so repeated analyses do not pile up
    # open figures; the Figure object itself stays usable by the caller.
    plt.close(fig)

    print(f"4. Analysis complete in {time.time() - start_time:.2f} seconds.")
    return fig


def build_demo():
    """Build the Gradio Blocks interface and wire up its event handlers.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    # --- Build the Gradio Interface using Blocks ---
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # State variables store data (like the original image) between user interactions
        original_image_state = gr.State()
        box_coords_state = gr.State(value=(0, 0))

        gr.Markdown("# 🎨 Color Filter Array Analysis")
        gr.Markdown(
            "Analyzes artifacts introduced during the camera's raw image processing. Inconsistencies in the **Color Filter Array (CFA)** interpolation pattern can reveal areas that have been spliced from another image or copy-pasted within the same image (copy-move).\n"
            "\n"
            "## Instructions:\n"
            "1. **Upload** an image.\n"
            "2. **Click** anywhere on the image to move the 128x128 selection box.\n"
            "3. Press the **Analyze Region** button to start processing."
        )

        with gr.Row():
            image_display = gr.Image(type="numpy", label="Selection Canvas", interactive=True)
            output_plot = gr.Plot(label="Analysis Results")

        analyze_button = gr.Button("Analyze Region", variant="primary")

        # --- Wire up the event listeners ---

        # 1. When a new image is uploaded, call on_upload_image
        image_display.upload(
            fn=on_upload_image,
            inputs=[image_display],
            outputs=[image_display, original_image_state, box_coords_state],
        )

        # 2. When the user clicks the image, call move_selection_box
        image_display.select(
            fn=move_selection_box,
            inputs=[original_image_state],
            outputs=[image_display, box_coords_state],
        )

        # 3. When the user clicks the analyze button, call analyze_region
        analyze_button.click(
            fn=analyze_region,
            inputs=[original_image_state, box_coords_state],
            outputs=[output_plot],
            # Show a progress bar during analysis
            show_progress="full",
        )

    return demo


if __name__ == "__main__":
    app = build_demo()
    app.launch()