Spaces:
Paused
Paused
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| from tensorflow.keras.applications import ResNet50 | |
| from tensorflow.keras.applications.resnet50 import preprocess_input | |
| from tensorflow.keras.preprocessing import image | |
| from skimage.metrics import structural_similarity as ssim | |
| import os | |
| import tempfile | |
| from PIL import Image | |
class ImageCharacterClassifier:
    """Image comparison helper built on ResNet50 features and SSIM.

    The instance owns an ImageNet-weighted ResNet50 with the classification
    head removed and average pooling enabled, and remembers the similarity
    threshold it was created with.
    """

    def __init__(self, similarity_threshold=0.5):
        """Remember the threshold and build the headless ResNet50 model.

        Args:
            similarity_threshold: Value stored on ``self.similarity_threshold``
                for callers to consult; it is not used inside this class.
        """
        self.similarity_threshold = similarity_threshold
        # include_top=False drops the classifier; pooling='avg' pools the
        # final feature map into one vector per image.
        self.model = ResNet50(weights='imagenet', include_top=False, pooling='avg')

    def load_and_preprocess_image(self, image_path, target_size=(224, 224)):
        """Load an image file and turn it into a ResNet50-ready batch of one.

        Args:
            image_path: Path of the image file to load.
            target_size: Size the image is resized to while loading.

        Returns:
            The output of ``preprocess_input`` applied to the image array
            after a leading batch axis has been added.
        """
        loaded = image.load_img(image_path, target_size=target_size)
        batch = np.expand_dims(image.img_to_array(loaded), axis=0)
        return preprocess_input(batch)

    def extract_features(self, image_path):
        """Return the ResNet50 prediction for the image at ``image_path``."""
        return self.model.predict(self.load_and_preprocess_image(image_path))

    def calculate_ssim(self, img1_path, img2_path):
        """Compute the structural similarity (SSIM) of two image files.

        Both images are reduced to grayscale, and the second one is resized
        to the first one's dimensions before scoring.

        Returns:
            The SSIM score, or ``0.0`` when either file cannot be read by
            OpenCV.
        """
        first = cv2.imread(img1_path)
        second = cv2.imread(img2_path)
        if first is None or second is None:
            return 0.0

        # Only multi-channel (3-D) arrays need the colour -> gray conversion.
        if first.ndim == 3:
            first = cv2.cvtColor(first, cv2.COLOR_BGR2GRAY)
        if second.ndim == 3:
            second = cv2.cvtColor(second, cv2.COLOR_BGR2GRAY)

        # cv2.resize expects (width, height), i.e. the reverse of .shape.
        height, width = first.shape[0], first.shape[1]
        second = cv2.resize(second, (width, height))
        return ssim(first, second)
# Fixed cut-offs for the first two similarity checks; the third check uses the
# caller-supplied similarity threshold.
_MAX_FEATURE_DIFF_LIMIT = 0.85
_AVG_FEATURE_DIFF_LIMIT = 0.5
_MIN_SSIM_SCORE = 0.4
_DEFAULT_SIMILARITY_THRESHOLD = 0.5

_RESULTS_HEADER = """
<div style='text-align: center; margin-bottom: 20px;'>
<h2 style='color: #2c3e50;'>Results</h2>
<p style='color: #7f8c8d;'>Reference image compared with uploaded images</p>
</div>
"""


def _upload_path(upload):
    """Return the filesystem path of an upload given as a path or an object with ``.name``."""
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return str(getattr(upload, "name", upload))


def _save_upload_as_jpeg(source_path, dest_path):
    """Re-encode ``source_path`` to ``dest_path``, trying PIL first and OpenCV second."""
    try:
        with Image.open(source_path) as img:
            rgb_array = np.array(img.convert('RGB'))
        cv2.imwrite(dest_path, cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR))
    except Exception as pil_error:
        print(f"PIL failed: {str(pil_error)}")
        fallback = cv2.imread(source_path)
        if fallback is None:
            raise ValueError(f"Could not read image: {source_path}") from pil_error
        cv2.imwrite(dest_path, fallback)


def _cosine_similarity(vec_a, vec_b):
    """Return the cosine similarity of two feature arrays (0.0 if either has zero norm)."""
    denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    if denominator == 0:
        return 0.0
    return np.dot(vec_a.flatten(), vec_b.flatten()) / denominator


def _similarity_verdict(ssim_score, max_feature_diff, avg_feature_diff,
                        feature_similarity, similarity_threshold):
    """Apply the three checks in order and return ``(is_similar, reason)``."""
    if (max_feature_diff > _MAX_FEATURE_DIFF_LIMIT
            or avg_feature_diff > _AVG_FEATURE_DIFF_LIMIT):
        return False, "Major physical differences detected (missing or misplaced features)"
    if ssim_score < _MIN_SSIM_SCORE:
        return False, "Overall structure is too different"
    if feature_similarity < similarity_threshold:
        return False, "Features don't match well enough"
    return True, "Images are similar"


def _result_card(display_name, is_similar, reason):
    """Build the HTML card reporting the verdict for one comparison image."""
    import html  # stdlib; imported locally so the file's import block is untouched

    status_color = "#27ae60" if is_similar else "#c0392b"  # Green or Red
    status_text = "SIMILAR" if is_similar else "NOT SIMILAR"
    status_icon = "✓" if is_similar else "✗"
    safe_name = html.escape(display_name)  # file name comes from the uploader
    return f"""
<div style='
margin: 15px 0;
padding: 15px;
border-radius: 8px;
background-color: {status_color}1a;
border: 2px solid {status_color};
display: flex;
align-items: center;
justify-content: space-between;
'>
<div style='display: flex; align-items: center;'>
<span style='
font-size: 24px;
margin-right: 10px;
color: {status_color};
'>{status_icon}</span>
<div>
<span style='color: #2c3e50; font-weight: bold; display: block;'>
{safe_name}
</span>
<span style='color: {status_color}; font-size: 12px;'>
{reason}
</span>
</div>
</div>
<div style='
color: {status_color};
font-weight: bold;
font-size: 16px;
'>{status_text}</div>
</div>
"""


def _error_card(title, message, with_margin):
    """Build the red HTML error card; ``title`` and ``message`` are HTML-escaped."""
    import html  # stdlib; imported locally so the file's import block is untouched

    margin_line = "margin: 15px 0;\n" if with_margin else ""
    return f"""
<div style='
{margin_line}padding: 15px;
border-radius: 8px;
background-color: #e74c3c1a;
border: 2px solid #e74c3c;
'>
<h3 style='color: #e74c3c; margin: 0;'>{html.escape(title)}</h3>
<p style='color: #e74c3c; margin: 5px 0 0 0;'>{html.escape(message)}</p>
</div>
"""


def process_images(reference_image, comparison_images, similarity_threshold):
    """Compare each uploaded image with the reference and report the verdicts.

    Every image is re-encoded to JPEG in a temporary directory, scored with
    SSIM and with ResNet50 feature statistics, and classified as similar or
    not. A failure on one comparison image is reported in the HTML and does
    not stop the remaining images.

    Args:
        reference_image: RGB image array for the reference picture, or None.
        comparison_images: Uploaded files, each either a path string or an
            object whose ``name`` attribute is the path.
        similarity_threshold: Minimum cosine similarity between the feature
            vectors for an image to count as similar. ``None`` falls back to
            0.5.

    Returns:
        A ``(html, images)`` tuple: the HTML report (or a plain message when
        an input is missing) and the list of successfully processed
        comparison images as RGB arrays.
    """
    try:
        if reference_image is None:
            return "Please upload a reference image.", []
        if not comparison_images:
            return "Please upload comparison images.", []

        threshold = (_DEFAULT_SIMILARITY_THRESHOLD if similarity_threshold is None
                     else float(similarity_threshold))

        # The directory and every file written into it vanish when the block exits.
        with tempfile.TemporaryDirectory() as temp_dir:
            # NOTE: this builds a fresh ResNet50 on every call.
            classifier = ImageCharacterClassifier(similarity_threshold=threshold)

            ref_path = os.path.join(temp_dir, "reference.jpg")
            cv2.imwrite(ref_path, cv2.cvtColor(reference_image, cv2.COLOR_RGB2BGR))

            results = []
            html_parts = [_RESULTS_HEADER]
            # Reference features are computed once and reused for every comparison.
            ref_features = classifier.extract_features(ref_path)

            for index, upload in enumerate(comparison_images):
                # Resolved outside the try so the error handler can always name the file.
                source_path = _upload_path(upload)
                display_name = os.path.basename(source_path)
                try:
                    comp_path = os.path.join(temp_dir, f"comparison_{index}.jpg")
                    _save_upload_as_jpeg(source_path, comp_path)

                    ssim_score = classifier.calculate_ssim(ref_path, comp_path)
                    comp_features = classifier.extract_features(comp_path)

                    feature_diff = np.abs(ref_features - comp_features)
                    avg_feature_diff = np.mean(feature_diff)
                    max_feature_diff = np.max(feature_diff)
                    feature_similarity = _cosine_similarity(ref_features, comp_features)

                    is_similar, reason = _similarity_verdict(
                        ssim_score, max_feature_diff, avg_feature_diff,
                        feature_similarity, classifier.similarity_threshold,
                    )

                    # Debug information
                    print(f"\nDebug for {display_name}:")
                    print(f"SSIM Score: {ssim_score:.3f}")
                    print(f"Max Feature Difference: {max_feature_diff:.3f}")
                    print(f"Average Feature Difference: {avg_feature_diff:.3f}")
                    print(f"Feature Similarity: {feature_similarity:.3f}")

                    html_parts.append(_result_card(display_name, is_similar, reason))

                    # Read the re-encoded file back so the gallery shows what was scored.
                    display_img = cv2.imread(comp_path)
                    if display_img is not None:
                        results.append(cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB))
                except Exception as e:
                    print(f"Error processing {source_path}: {str(e)}")
                    html_parts.append(
                        _error_card(f"Error processing: {display_name}", str(e), True)
                    )

            return "".join(html_parts), results
    except Exception as e:
        print(f"Main error: {str(e)}")
        return _error_card("Error", str(e), False), []
def create_interface():
    """Build and return the Gradio Blocks UI for the similarity classifier.

    The left column holds the inputs (reference image, comparison files,
    threshold slider and the submit button); the right column holds the HTML
    report and a gallery of processed images. Clicking the button calls
    ``process_images`` with the three inputs and routes its two return values
    to the report and the gallery.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    with gr.Blocks() as interface:
        gr.Markdown("# Image Similarity Classifier")
        gr.Markdown("Upload a reference image and up to 10 comparison images to check similarity.")

        with gr.Row():
            with gr.Column():
                reference_input = gr.Image(
                    label="Reference Image",
                    type="numpy",
                    image_mode="RGB"
                )
                # gr.File defines no `maximum` keyword, so the previous
                # `maximum=10` argument is dropped; the 10-file limit named in
                # the label is advisory only.
                comparison_input = gr.File(
                    label="Comparison Images (Upload up to 10)",
                    file_count="multiple",
                    file_types=["image"]
                )
                threshold_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Similarity Threshold"
                )
                submit_button = gr.Button("Compare Images", variant="primary")

            with gr.Column():
                output_html = gr.HTML(label="Results")
                output_gallery = gr.Gallery(
                    label="Processed Images",
                    columns=5,
                    show_label=True,
                    height="auto"
                )

        # Event wiring must happen inside the Blocks context.
        submit_button.click(
            fn=process_images,
            inputs=[reference_input, comparison_input, threshold_slider],
            outputs=[output_html, output_gallery]
        )

    return interface
# Script entry point: build the UI and serve it with a public share link.
if __name__ == "__main__":
    create_interface().launch(share=True)