Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import random | |
| import numpy as np | |
| import gradio as gr | |
| try: | |
| from tensorflow.keras.models import Model | |
| from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input | |
| except ImportError: | |
| try: | |
| from keras.models import Model | |
| from keras.applications.vgg19 import VGG19, preprocess_input | |
| except ImportError: | |
| pass | |
| import matplotlib.pyplot as plt | |
| from scipy.special import kl_div as scipy_kl_div | |
| from skimage.metrics import structural_similarity as ssim | |
| import warnings | |
# --- Configuration ---
# Set the default task.
TASK = "facades"
# Directory holding the ground-truth ("real") images of the active task.
PATH = os.path.join("datasets", TASK, "real")
# Cached list of real-image paths for the active task; (re)filled by gather_images().
images = []
# VGG19 feature extractor used by the "Perceptual" metric; None when unavailable.
perceptual_model = None
# --- Model Loading ---
# Attempt to load the VGG19 model for the perceptual loss metric.
# Failure is deliberately non-fatal (e.g. TensorFlow/Keras not importable, so
# VGG19 is undefined, or the ImageNet weights cannot be obtained): the app keeps
# running with only the "Perceptual" metric disabled. The reason is reported as
# a warning instead of being silently discarded.
try:
    vgg = VGG19(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    vgg.trainable = False
    perceptual_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block5_conv4').output, name="perceptual_model")
except Exception as e:
    perceptual_model = None
    warnings.warn(f"Could not load the VGG19 perceptual model; the 'Perceptual' metric is disabled: {e!r}")
# --- Utility Functions ---
def safe_normalize_heatmap(heatmap):
    """Safely normalize a heatmap to the 0-255 uint8 range for visualization.

    Args:
        heatmap: Numeric array of any shape, or None. Non-finite entries are
            tolerated.

    Returns:
        A uint8 array of the same shape as ``heatmap``, linearly rescaled so
        that the minimum maps to 0 and the maximum to 255 (values truncated
        toward zero). NaN entries take the finite minimum and +/-inf saturate
        at the finite maximum/minimum. A constant (or entirely non-finite)
        input yields all zeros. ``None`` or an empty array yields a 64x64 zero
        image.
    """
    if heatmap is None or heatmap.size == 0:
        return np.zeros((64, 64), dtype=np.uint8)
    heatmap = heatmap.astype(np.float32)
    finite_mask = np.isfinite(heatmap)
    if not finite_mask.all():
        if finite_mask.any():
            finite_values = heatmap[finite_mask]
            min_val_safe = float(finite_values.min())
            max_val_safe = float(finite_values.max())
        else:
            min_val_safe = max_val_safe = 0.0
        # NaN becomes the finite minimum rather than 0.0: a 0.0 below the real
        # data range would stretch the colour scale and wash out the contrast
        # between the finite values.
        heatmap = np.nan_to_num(heatmap, nan=min_val_safe, posinf=max_val_safe, neginf=min_val_safe)
    min_val = np.min(heatmap)
    max_val = np.max(heatmap)
    range_val = max_val - min_val
    if range_val <= 1e-9:
        # Flat map: nothing to contrast, render as black.
        return np.zeros(heatmap.shape, dtype=np.uint8)
    normalized_heatmap = ((heatmap - min_val) / range_val) * 255.0
    return np.clip(normalized_heatmap, 0, 255).astype(np.uint8)
# --- Image Comparison Metrics ---
def KL_divergence(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate a block-wise Kullback-Leibler divergence between two images.

    Each colour channel is tiled into ``block_size`` x ``block_size`` blocks.
    Within a block, the epsilon-shifted intensities of each image are
    normalised to sum to 1 and ``scipy.special.kl_div`` is summed over the
    block. Blocks on the bottom/right border are smaller when the image size is
    not a multiple of the block size; they are still evaluated, so the whole
    image is covered.

    Args:
        img_real: Reference image; converted with ``cv2.COLOR_BGR2RGB``.
        img_fake: Image to compare; must have the same shape as ``img_real``.
        epsilon: Added to every intensity so that no probability is zero.
        block_size: Block side length, clamped to [1, min(height, width)].
        sum_channels: When True, ``result["SUM"]`` receives the total over the
            channels and the channel-averaged heatmap.

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": float, "HEATMAP": (H, W) float32}}``, where
        each heatmap pixel holds its block's KL sum divided by the block's pixel
        count. Returns None if an input is missing, the shapes differ, or the
        colour conversion fails.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    step = max(1, min(int(block_size), height, width))
    sum_heatmap = img_dict["SUM"]["HEATMAP"]
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        heatmap = img_dict[key]["HEATMAP"]
        # Iterate up to the full height/width (not height - step + 1) so that a
        # trailing, smaller-than-step block at the border is not skipped.
        for i in range(0, height, step):
            for j in range(0, width, step):
                block_gt = img_real_rgb[i:i + step, j:j + step, channel_idx].ravel() + epsilon
                block_pred = img_fake_rgb[i:i + step, j:j + step, channel_idx].ravel() + epsilon
                gt_total = np.sum(block_gt)
                pred_total = np.sum(block_pred)
                if gt_total <= 0 or pred_total <= 0:
                    continue
                kl_values = scipy_kl_div(block_gt / gt_total, block_pred / pred_total)
                kl_values = np.nan_to_num(kl_values, nan=0.0, posinf=0.0, neginf=0.0)
                kl_sum_block = np.sum(kl_values)
                if not np.isfinite(kl_sum_block):
                    continue
                channel_sum += kl_sum_block
                # Divide by the real pixel count: border blocks may be smaller.
                mean_kl_block = kl_sum_block / block_gt.size
                heatmap[i:i + step, j:j + step] = mean_kl_block
                if sum_channels:
                    sum_heatmap[i:i + step, j:j + step] += mean_kl_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def L1_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate a block-wise L1 (mean absolute error) loss between two images.

    Each colour channel (intensities scaled to [0, 1]) is tiled into
    ``block_size`` x ``block_size`` blocks. Blocks on the bottom/right border
    are smaller when the image size is not a multiple of the block size; they
    are still evaluated, so the whole image is covered.

    Args:
        img_real: Reference image; converted with ``cv2.COLOR_BGR2RGB``.
        img_fake: Image to compare; must have the same shape as ``img_real``.
        epsilon: Unused; kept so that all metrics share one signature.
        block_size: Block side length, clamped to [1, min(height, width)].
        sum_channels: When True, ``result["SUM"]`` receives the total over the
            channels and the channel-averaged heatmap.

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": float, "HEATMAP": (H, W) float32}}``, where
        "SUM" is the summed absolute difference and each heatmap pixel holds its
        block's mean absolute difference. Returns None if an input is missing,
        the shapes differ, or the colour conversion fails.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    step = max(1, min(int(block_size), height, width))
    sum_heatmap = img_dict["SUM"]["HEATMAP"]
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        heatmap = img_dict[key]["HEATMAP"]
        # Iterate up to the full height/width so border remainders are included.
        for i in range(0, height, step):
            for j in range(0, width, step):
                block_pred = img_fake_rgb[i:i + step, j:j + step, channel_idx]
                block_gt = img_real_rgb[i:i + step, j:j + step, channel_idx]
                sum_result_block = np.sum(np.abs(block_pred - block_gt))
                channel_sum += sum_result_block
                # Divide by the real pixel count: border blocks may be smaller.
                mean_l1_block = sum_result_block / block_gt.size
                heatmap[i:i + step, j:j + step] = mean_l1_block
                if sum_channels:
                    sum_heatmap[i:i + step, j:j + step] += mean_l1_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def MSE_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate a block-wise MSE (mean squared error) loss between two images.

    Each colour channel (intensities scaled to [0, 1]) is tiled into
    ``block_size`` x ``block_size`` blocks. Blocks on the bottom/right border
    are smaller when the image size is not a multiple of the block size; they
    are still evaluated, so the whole image is covered.

    Args:
        img_real: Reference image; converted with ``cv2.COLOR_BGR2RGB``.
        img_fake: Image to compare; must have the same shape as ``img_real``.
        epsilon: Unused; kept so that all metrics share one signature.
        block_size: Block side length, clamped to [1, min(height, width)].
        sum_channels: When True, ``result["SUM"]`` receives the total over the
            channels and the channel-averaged heatmap.

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": float, "HEATMAP": (H, W) float32}}``, where
        "SUM" is the summed squared difference and each heatmap pixel holds its
        block's mean squared difference. Returns None if an input is missing,
        the shapes differ, or the colour conversion fails.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    step = max(1, min(int(block_size), height, width))
    sum_heatmap = img_dict["SUM"]["HEATMAP"]
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        heatmap = img_dict[key]["HEATMAP"]
        # Iterate up to the full height/width so border remainders are included.
        for i in range(0, height, step):
            for j in range(0, width, step):
                block_pred = img_fake_rgb[i:i + step, j:j + step, channel_idx]
                block_gt = img_real_rgb[i:i + step, j:j + step, channel_idx]
                sum_result_block = np.sum(np.square(block_pred - block_gt))
                channel_sum += sum_result_block
                # Divide by the real pixel count: border blocks may be smaller.
                mean_mse_block = sum_result_block / block_gt.size
                heatmap[i:i + step, j:j + step] = mean_mse_block
                if sum_channels:
                    sum_heatmap[i:i + step, j:j + step] += mean_mse_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def SSIM_loss(img_real, img_fake, block_size=7, sum_channels=False):
    """Calculate an SSIM-based loss map, max(0, 1 - SSIM), between two images.

    SSIM is computed per colour channel with a Gaussian-weighted window whose
    side length is ``block_size``, forced odd and clamped to the image size
    (minimum 3).

    Args:
        img_real: Reference image; converted with ``cv2.COLOR_BGR2RGB`` and
            compared with ``data_range=255``, i.e. 8-bit intensities are assumed.
        img_fake: Image to compare; must have the same shape as ``img_real``.
        block_size: Requested SSIM window size.
        sum_channels: When True, ``result["SUM"]`` receives the total over the
            channels and the channel-averaged heatmap.

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": float, "HEATMAP": (H, W) array}}`` holding
        the per-pixel loss map and its sum. A channel for which SSIM raises
        ``ValueError`` (e.g. the image is smaller than the window) is reported
        as all zeros. Returns None if an input is missing, the shapes differ,
        or the colour conversion fails.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB)
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB)
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    win_size = int(block_size)
    if win_size % 2 == 0:
        win_size += 1
    win_size = max(3, min(win_size, height, width))
    if win_size % 2 == 0:
        # Clamping to an even image dimension made the window even, which
        # structural_similarity rejects; use the next smaller odd size.
        win_size -= 1
    # With gaussian_weights=True, skimage filters with sigma (default 1.5, which
    # gives a fixed 11-tap window) and does not use win_size for the filter, so
    # the user's window size had no effect. skimage truncates the Gaussian at
    # 3.5 sigma, so this sigma makes the filter exactly win_size taps wide.
    sigma = (win_size - 1) / 7.0
    for channel_idx, key in enumerate(("R", "G", "B")):
        try:
            _, ssim_map = ssim(
                img_real_rgb[:, :, channel_idx],
                img_fake_rgb[:, :, channel_idx],
                win_size=win_size,
                data_range=255,
                full=True,
                gaussian_weights=True,
                sigma=sigma,
            )
        except ValueError:
            img_dict[key]["SUM"] = 0.0
            img_dict[key]["HEATMAP"] = np.zeros((height, width), dtype=np.float32)
            continue
        ssim_loss_map = np.maximum(0.0, 1.0 - ssim_map)
        img_dict[key]["SUM"] = np.sum(ssim_loss_map)
        img_dict[key]["HEATMAP"] = ssim_loss_map
        if sum_channels:
            img_dict["SUM"]["HEATMAP"] += ssim_loss_map
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def cosine_similarity_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate a block-wise cosine-similarity loss (1 - cos) between two images.

    Each colour channel (intensities scaled to [0, 1]) is tiled into
    ``block_size`` x ``block_size`` blocks, and each block's pixels are
    flattened into a vector. The block loss is ``1 - clip(cos, -1, 1)``. Two
    all-zero blocks count as identical (loss 0); a zero block against a
    non-zero one scores loss 1. Blocks on the bottom/right border are smaller
    when the image size is not a multiple of the block size; they are still
    evaluated.

    Args:
        img_real: Reference image; converted with ``cv2.COLOR_BGR2RGB``.
        img_fake: Image to compare; must have the same shape as ``img_real``.
        epsilon: Threshold below which a norm (product) is treated as zero.
        block_size: Block side length, clamped to [1, min(height, width)].
        sum_channels: When True, ``result["SUM"]`` receives the total over the
            channels and the channel-averaged heatmap.

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": float, "HEATMAP": (H, W) float32}}``, where
        each heatmap pixel holds its block's loss and "SUM" is the sum of the
        block losses. Returns None if an input is missing, the shapes differ,
        or the colour conversion fails.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None
    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    step = max(1, min(int(block_size), height, width))
    sum_heatmap = img_dict["SUM"]["HEATMAP"]
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        heatmap = img_dict[key]["HEATMAP"]
        # Iterate up to the full height/width so border remainders are included.
        for i in range(0, height, step):
            for j in range(0, width, step):
                block_pred = img_fake_rgb[i:i + step, j:j + step, channel_idx].ravel()
                block_gt = img_real_rgb[i:i + step, j:j + step, channel_idx].ravel()
                norm_pred = np.linalg.norm(block_pred)
                norm_gt = np.linalg.norm(block_gt)
                cosine_sim = 0.0
                if norm_pred * norm_gt > epsilon:
                    cosine_sim = np.dot(block_pred, block_gt) / (norm_pred * norm_gt)
                elif norm_pred < epsilon and norm_gt < epsilon:
                    cosine_sim = 1.0
                result_block = 1.0 - np.clip(cosine_sim, -1.0, 1.0)
                channel_sum += result_block
                heatmap[i:i + step, j:j + step] = result_block
                if sum_channels:
                    sum_heatmap[i:i + step, j:j + step] += result_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def TV_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
    """Calculate a block-wise Total Variation (TV) difference between two images.

    Each colour channel (intensities scaled to [0, 1]) is tiled into blocks.
    A block's TV is the sum of absolute horizontal and vertical neighbour
    differences inside it, and the block loss is ``|TV(fake) - TV(real)|``.
    Blocks on the bottom/right border are smaller when the image size is not a
    multiple of the block size; they are still evaluated.

    Args:
        img_real: Reference image; converted with ``cv2.COLOR_BGR2RGB``.
        img_fake: Image to compare; must have the same shape as ``img_real``.
        epsilon: Unused; kept so that all metrics share one signature.
        block_size: Block side length; at least 2 (so a block has neighbours),
            but never larger than min(height, width).
        sum_channels: When True, ``result["SUM"]`` receives the total over the
            channels and the channel-averaged heatmap.

    Returns:
        ``{"R"|"G"|"B"|"SUM": {"SUM": float, "HEATMAP": (H, W) float32}}``, where
        each heatmap pixel holds its block's loss and "SUM" is the sum of the
        block losses. Returns None if an input is missing, the shapes differ,
        or the colour conversion fails.
    """
    if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
        return None
    try:
        img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_fake_rgb = cv2.cvtColor(img_fake, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    except cv2.error:
        return None

    def total_variation(block):
        # Anisotropic TV: |horizontal diffs| + |vertical diffs|; empty diffs sum to 0.
        return np.sum(np.abs(np.diff(block, axis=1))) + np.sum(np.abs(np.diff(block, axis=0)))

    height, width, channels = img_real_rgb.shape
    img_dict = {
        key: {"SUM": 0.0, "HEATMAP": np.zeros((height, width), dtype=np.float32)}
        for key in ("R", "G", "B", "SUM")
    }
    step = max(1, min(max(2, int(block_size)), height, width))
    sum_heatmap = img_dict["SUM"]["HEATMAP"]
    for channel_idx, key in enumerate(("R", "G", "B")):
        channel_sum = 0.0
        heatmap = img_dict[key]["HEATMAP"]
        # Iterate up to the full height/width so border remainders are included.
        for i in range(0, height, step):
            for j in range(0, width, step):
                block_pred = img_fake_rgb[i:i + step, j:j + step, channel_idx]
                block_gt = img_real_rgb[i:i + step, j:j + step, channel_idx]
                result_block = np.abs(total_variation(block_pred) - total_variation(block_gt))
                channel_sum += result_block
                heatmap[i:i + step, j:j + step] = result_block
                if sum_channels:
                    sum_heatmap[i:i + step, j:j + step] += result_block
        img_dict[key]["SUM"] = channel_sum
    if sum_channels:
        img_dict["SUM"]["SUM"] = img_dict["R"]["SUM"] + img_dict["G"]["SUM"] + img_dict["B"]["SUM"]
        img_dict["SUM"]["HEATMAP"] /= max(1, channels)
    return img_dict
def perceptual_loss(img_real, img_fake, model, block_size=4):
    """Calculate a perceptual (feature-space) loss using a pre-trained model.

    Both images are resized to the model's input size, converted with
    ``cv2.COLOR_BGR2RGB``, passed through ``preprocess_input`` and run through
    ``model`` in a single batch. The loss is the squared difference of the two
    resulting feature maps.

    NOTE(review): the BGR->RGB conversion assumes OpenCV channel order. Arrays
    coming directly from Gradio image components (e.g. uploads) are usually RGB
    already; confirm the channel order callers actually pass.

    Args:
        img_real: Reference image (height, width, channels).
        img_fake: Image to compare; must have the same shape as ``img_real``.
        model: Keras model whose ``input_shape`` is (batch, height, width, ...),
            e.g. the module's VGG19 ``perceptual_model``.
        block_size: Unused; kept so that all metrics share one signature.

    Returns:
        ``{"SUM": {"SUM": total squared feature difference, "HEATMAP": float32
        map of the channel-mean squared difference resized to the input size}}``,
        or None if an argument is missing, the shapes differ, or preprocessing
        or inference fails.
    """
    if img_real is None or img_fake is None or model is None or img_real.shape != img_fake.shape:
        return None
    original_height, original_width = img_real.shape[:2]
    try:
        # model.input_shape is (batch, height, width, channels); cv2 wants (width, height).
        target_height, target_width = model.input_shape[1], model.input_shape[2]
        batch = np.stack([
            cv2.cvtColor(
                cv2.resize(img, (target_width, target_height), interpolation=cv2.INTER_AREA),
                cv2.COLOR_BGR2RGB,
            )
            for img in (img_real, img_fake)
        ])
        batch = preprocess_input(batch)
    except Exception:
        return None
    try:
        # One forward pass for both images instead of two; verbose=0 keeps a
        # progress bar from being printed on every comparison.
        features = model.predict(batch, verbose=0)
    except Exception:
        return None
    feature_mse = np.square(features[0] - features[1])
    total_loss = np.sum(feature_mse)
    heatmap_features = np.mean(feature_mse, axis=-1)
    heatmap_original_size = cv2.resize(heatmap_features, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
    return {"SUM": {"SUM": total_loss, "HEATMAP": heatmap_original_size.astype(np.float32)}}
# --- Gradio Core Functions ---
def gather_images(task):
    """Load a random (real, fake) image pair for ``task`` from disk.

    Real images are listed from ``datasets/<task>/real``. The listing is cached
    in the module-level ``images`` and refreshed when the task changes or the
    cache is empty. The matching fake image is the file with the same name in
    ``datasets/<task>/fake``. If its size differs, it is resized to the real
    image's size.

    Args:
        task: Dataset task name (a sub-directory of ``datasets``).

    Returns:
        ``(real_img, fake_img, status_message)``. Loaded images are converted
        from OpenCV's BGR order to RGB, the order in which Gradio image
        components interpret numpy arrays. On failure, black placeholder images
        are returned together with an error message.
    """
    global TASK, PATH, images
    if TASK != task or not images:
        PATH = os.path.join("datasets", task, "real")
        TASK = task
        images = []
        if not os.path.isdir(PATH):
            placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
            return placeholder, placeholder, f"Error: Directory for task '{task}' not found: {PATH}"
        try:
            valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
            images = [os.path.join(PATH, f) for f in os.listdir(PATH) if f.lower().endswith(valid_extensions)]
        except Exception as e:
            placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
            return placeholder, placeholder, f"Error reading directory {PATH}: {e}"
        if not images:
            placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
            return placeholder, placeholder, f"Error: No valid image files found in: {PATH}"
    if not images:
        placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
        return placeholder, placeholder, f"Error: No images available for task '{task}'."
    try:
        real_img_path = random.choice(images)
        img_filename = os.path.basename(real_img_path)
        fake_img_path = os.path.join("datasets", task, "fake", img_filename)
        real_img = cv2.imread(real_img_path)
        fake_img = cv2.imread(fake_img_path)
        # cv2.imread yields BGR, but gr.Image renders numpy arrays as RGB.
        # Without this conversion, sample pairs were shown with red and blue swapped.
        if real_img is not None:
            real_img = cv2.cvtColor(real_img, cv2.COLOR_BGR2RGB)
        if fake_img is not None:
            fake_img = cv2.cvtColor(fake_img, cv2.COLOR_BGR2RGB)
        placeholder_shape = (256, 256, 3)
        if real_img is None:
            fake_out = fake_img if fake_img is not None else np.zeros(placeholder_shape, dtype=np.uint8)
            return np.zeros(placeholder_shape, dtype=np.uint8), fake_out, f"Error: Failed to load real image: {real_img_path}"
        if fake_img is None:
            return real_img, np.zeros(real_img.shape, dtype=np.uint8), f"Error: Failed to load fake image: {fake_img_path}"
        if real_img.shape != fake_img.shape:
            target_dims = (real_img.shape[1], real_img.shape[0])
            fake_img = cv2.resize(fake_img, target_dims, interpolation=cv2.INTER_AREA)
        return real_img, fake_img, f"Sample pair for '{task}' loaded successfully."
    except Exception as e:
        placeholder = np.zeros((256, 256, 3), dtype=np.uint8)
        return placeholder, placeholder, f"An unexpected error occurred during image loading: {e}"
def run_comparison(real, fake, measurement, block_size_val):
    """Compute the selected metric for an image pair and render its heatmap.

    If the two images differ in size, ``fake`` is resized to match ``real`` and
    a warning is prepended to the status message.

    Args:
        real: Reference image array (from a Gradio image component).
        fake: Image array to compare against ``real``.
        measurement: Metric name as listed in the "Comparison Metric" dropdown.
        block_size_val: Block/window size; converted with ``int()``.

    Returns:
        ``(heatmap_rgb, status_message)``. On any failure, the heatmap is a
        64x64 black RGB image and the message describes the problem.
    """
    blank = np.zeros((64, 64, 3), dtype=np.uint8)
    if not (isinstance(real, np.ndarray) and isinstance(fake, np.ndarray)):
        return blank, "Error: Input image(s) missing or invalid. Please load or upload a pair of images."

    prefix = ""
    if real.shape != fake.shape:
        prefix = f"Warning: Input images have different shapes ({real.shape} vs {fake.shape}). Resizing fake image to match real. "
        fake = cv2.resize(fake, (real.shape[1], real.shape[0]), interpolation=cv2.INTER_AREA)

    size = int(block_size_val)
    try:
        blockwise = {
            "Kullback-Leibler Divergence": KL_divergence,
            "L1-Loss": L1_loss,
            "MSE": MSE_loss,
            "SSIM": SSIM_loss,
            "Cosine Similarity": cosine_similarity_loss,
            "TV": TV_loss,
        }
        if measurement in blockwise:
            result = blockwise[measurement](real, fake, block_size=size, sum_channels=True)
        elif measurement == "Perceptual":
            if perceptual_model is None:
                return blank, "Error: Perceptual model not loaded. Cannot calculate Perceptual loss."
            result = perceptual_loss(real, fake, model=perceptual_model, block_size=size)
        else:
            return blank, f"Error: Unknown measurement '{measurement}'."
    except Exception as exc:
        return blank, f"Error during {measurement} calculation: {exc}"

    if result is None or "SUM" not in result or "HEATMAP" not in result["SUM"]:
        return blank, f"{measurement} calculation failed or returned an invalid result structure."
    raw = result["SUM"]["HEATMAP"]
    if not isinstance(raw, np.ndarray) or raw.size == 0:
        return blank, f"Generated heatmap is invalid or empty for {measurement}."

    try:
        colored_bgr = cv2.applyColorMap(safe_normalize_heatmap(raw), cv2.COLORMAP_HOT)
        rendered = cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2RGB)
    except Exception as exc:
        return blank, f"Error during heatmap coloring: {exc}"
    return rendered, prefix + f"{measurement} comparison successful."
def clear_uploads(msg):
    """Blank both image displays and show ``msg`` in the status box.

    Returns:
        The triple ``(None, None, msg)``, i.e. (real image, fake image, status).
    """
    cleared_real = cleared_fake = None
    return cleared_real, cleared_fake, msg
def load_and_compare_initial(task):
    """Load a sample pair for ``task`` and run the default comparison on it.

    Used as the app's start-up handler. The metric ("Cosine Similarity") and
    block size (8) match the initial values of the UI's metric dropdown and
    block-size slider.

    Returns:
        ``(real_img, fake_img, heatmap, status)``, where ``status`` is the
        loading message and the comparison message separated by a newline.
    """
    real_img, fake_img, load_msg = gather_images(task)
    heatmap, compare_msg = run_comparison(real_img, fake_img, "Cosine Similarity", 8)
    return real_img, fake_img, heatmap, f"{load_msg}\n{compare_msg}"
# --- Gradio UI Definition ---
# Layout: status box on top, then three columns (image sources | image pair and
# metric controls | resulting heatmap). `demo` is the app launched at the bottom.
theme = gr.themes.Soft(primary_hue="blue", secondary_hue="orange")
with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !important; margin: auto; }") as demo:
    gr.Markdown("# GAN vs Ground Truth Image Comparison")
    gr.Markdown("Compare images by loading a sample pair from a dataset or by uploading your own. Choose a comparison metric and run the analysis to see the difference heatmap.")
    status_message = gr.Textbox(label="Status / Errors", lines=2, interactive=False, show_copy_button=True)
    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 1. Get Images")
            with gr.Tabs():
                with gr.TabItem("Load from Dataset"):
                    task_dropdown = gr.Dropdown(
                        ["facades"], value=TASK,
                        info="Select the dataset task.",
                        label="Dataset Task"
                    )
                    # Label previously began with a mis-encoded emoji ("๐ ").
                    sample_button = gr.Button("Get New Sample Pair", variant="secondary")
                with gr.TabItem("Upload Images"):
                    gr.Markdown("Upload your own images to compare.")
                    upload_real_img = gr.Image(type="numpy", label="Upload Real/Reference Image")
                    upload_fake_img = gr.Image(type="numpy", label="Upload Fake/Comparison Image")
        with gr.Column(scale=2, min_width=600):
            gr.Markdown("### 2. View Images & Run Comparison")
            with gr.Row():
                real_img_display = gr.Image(type="numpy", label="Real Image (Ground Truth)", height=350, interactive=False)
                fake_img_display = gr.Image(type="numpy", label="Fake Image (Generated by GAN)", height=350, interactive=False)
            with gr.Row():
                measurement_dropdown = gr.Dropdown(
                    ["Kullback-Leibler Divergence", "L1-Loss", "MSE", "SSIM", "Cosine Similarity", "TV", "Perceptual"],
                    value="Cosine Similarity",
                    info="Select the comparison metric.",
                    label="Comparison Metric",
                    scale=2
                )
                block_size_slider = gr.Slider(
                    minimum=2, maximum=64, value=8, step=2,
                    info="Size of the block/window for comparison.",
                    label="Block/Window Size",
                    scale=1
                )
            # Label previously began with a mis-encoded emoji ("๐ ").
            run_button = gr.Button("Run Comparison", variant="primary")
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 3. See Result")
            heatmap_display = gr.Image(type="numpy", label="Comparison Heatmap (Difference)", height=350, interactive=False)
    # --- Event Listeners ---
    # Load initial sample and run comparison when the app starts
    demo.load(
        fn=load_and_compare_initial,
        inputs=[task_dropdown],
        outputs=[real_img_display, fake_img_display, heatmap_display, status_message]
    )
    sample_button.click(
        fn=gather_images,
        inputs=[task_dropdown],
        outputs=[real_img_display, fake_img_display, status_message]
    )
    # Uploads are mirrored into the displays that run_comparison reads from.
    upload_real_img.upload(
        fn=lambda x: x,
        inputs=[upload_real_img],
        outputs=[real_img_display]
    )
    upload_fake_img.upload(
        fn=lambda x: x,
        inputs=[upload_fake_img],
        outputs=[fake_img_display]
    )
    run_button.click(
        fn=run_comparison,
        inputs=[real_img_display, fake_img_display, measurement_dropdown, block_size_slider],
        outputs=[heatmap_display, status_message]
    )
    # Clear the displayed pair on task change. A no-input lambda replaces the
    # previous hidden Textbox that existed only to carry this constant message.
    task_dropdown.change(
        fn=lambda: clear_uploads("Task changed. Please get a new sample."),
        inputs=None,
        outputs=[real_img_display, fake_img_display, status_message]
    )
# --- Application Entry Point ---
# When run as a script: report the model/dataset status on stdout, then start
# the Gradio server.
if __name__ == "__main__":
    separator = "-------------------------------------------------------------"
    print(separator)
    print("Verifying VGG19 model status...")
    print(
        "VGG19 model loaded successfully."
        if perceptual_model is not None
        else "WARNING: VGG19 model failed to load. 'Perceptual' metric will be unavailable."
    )
    print(separator)
    print(f"Checking initial dataset path: {PATH}")
    if os.path.isdir(PATH):
        print("Initial dataset path seems valid.")
    else:
        print(f"WARNING: Initial dataset path not found: {PATH}")
        print(f" Please ensure the directory '{os.path.join('datasets', TASK, 'real')}' exists.")
    print(separator)
    for banner_line in (
        "Launching Gradio App...",
        "Access the app in your browser, usually at: http://127.0.0.1:7860",
        separator,
    ):
        print(banner_line)
    demo.launch(share=False, debug=False)