# Source: Hugging Face Spaces page (status: Paused)
import os
from io import BytesIO

# TensorFlow reads these variables while it is imported/initialised, so they
# must be set before the first TensorFlow import below to be honoured.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # disable TensorFlow's oneDNN custom ops
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide all GPUs: TensorFlow runs on CPU only

import cv2  # noqa: E402
import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
from PIL import Image  # noqa: E402
from skimage.metrics import structural_similarity as ssim  # noqa: E402
from tensorflow.keras.applications import ResNet50  # noqa: E402
from tensorflow.keras.applications.resnet50 import preprocess_input  # noqa: E402
# --- DOCUMENTATION STRINGS ---
# Markdown rendered inside the "User Guidelines and Documentation" accordion.

GUIDELINE_SETUP = """
## 1. Quick Start Guide: Setup and Run Instructions

This application uses a combination of advanced feature extraction (ResNet50) and structural analysis (SSIM) to determine if comparison images are structurally and semantically similar to a reference image.

1. **Upload Reference:** Upload the main image you want to compare against in the 'Reference Image' box.
2. **Upload Comparisons:** Upload one or more images you want to test for similarity in the 'Comparison Images' file upload area.
3. **Set Threshold:** Adjust the 'Similarity Threshold' slider. This controls the sensitivity for structural similarity (SSIM).
4. **Run:** Click the **"Compare Images"** button.
5. **Review:** Results will appear in the 'Results' panel, indicating if each comparison image is "SIMILAR" or "NOT SIMILAR".
"""

GUIDELINE_INPUT = """
## 2. Expected Inputs and Preprocessing

| Input Field | Purpose | Requirement |
| :--- | :--- | :--- |
| **Reference Image** | The baseline image against which all others will be compared. | Must be a single image file (JPG, PNG). |
| **Comparison Images** | One or more images to be tested for similarity. | One or more image files (JPG, PNG), uploaded with the file selector. |
| **Similarity Threshold** | A slider controlling the sensitivity (0.0 to 1.0) for structural similarity (SSIM). | Higher values (closer to 1.0) mean stricter similarity requirements. Default is 0.5. |

**Image Preprocessing:** All uploaded images are automatically resized to 224x224 pixels and standardized according to the requirements of the ResNet model before feature extraction. For the SSIM check, both images are converted to grayscale and each comparison image is resized to the reference image's dimensions.
"""

GUIDELINE_OUTPUT = """
## 3. Expected Outputs (Similarity Results)

The application provides two main outputs:

1. **Results (HTML Panel):**
    * A list detailing the outcome for each comparison image.
    * Status: **SIMILAR** (Green) or **NOT SIMILAR** (Red), followed by its SSIM score.
2. **Processed Images (Gallery):**
    * A gallery of the comparison images that were loaded successfully, converted to RGB.

### How Similarity is Determined:

The classification relies on two checks: Structural Similarity (SSIM) and Deep Feature Distance (ResNet). An image is marked "SIMILAR" if both structural and semantic properties suggest a close match.
"""
# --- CLASSIFIER CLASS ---
class ImageCharacterClassifier:
    """Extract ResNet50 deep features and compute grayscale SSIM for image pairs.

    Args:
        similarity_threshold: SSIM cut-off, stored on the instance as
            ``similarity_threshold``; the methods of this class do not read it.
    """

    # Building ResNet50 with ImageNet weights is slow, so the first instance
    # builds it and every later instance reuses the same model object.
    # NOTE(review): the shared model is used by whichever thread serves a
    # request; this assumes requests are handled one at a time — confirm if
    # the app's concurrency is raised.
    _shared_model = None

    def __init__(self, similarity_threshold=0.5):
        if ImageCharacterClassifier._shared_model is None:
            ImageCharacterClassifier._shared_model = ResNet50(
                weights="imagenet", include_top=False, pooling="avg"
            )
        self.model = ImageCharacterClassifier._shared_model
        self.similarity_threshold = similarity_threshold

    def load_and_preprocess_image(self, img):
        """Turn a PIL image into a ResNet50 input batch of shape (1, 224, 224, 3)."""
        rgb = np.array(img.convert("RGB"))
        resized = cv2.resize(rgb, (224, 224))
        # Keras' ResNet50 preprocess_input applies the normalisation the
        # ImageNet-trained weights expect.
        return preprocess_input(np.expand_dims(resized, axis=0))

    def extract_features(self, img):
        """Return the global-average-pooled ResNet50 features of a PIL image.

        The result is a batch of one feature vector, shape (1, num_features).
        """
        return self.model.predict(self.load_and_preprocess_image(img), verbose=0)

    def calculate_ssim(self, img1, img2):
        """Return the SSIM score between two images, compared in grayscale.

        ``img2`` is resized to ``img1``'s height and width first. Each input
        may be 2-D grayscale, (H, W, 1), RGB (H, W, 3) or RGBA (H, W, 4).
        """
        img1_gray = self._to_gray(img1)
        img2_gray = self._to_gray(img2)
        img2_gray = cv2.resize(img2_gray, (img1_gray.shape[1], img1_gray.shape[0]))
        # Use the dtype's full range (255 for uint8) instead of the observed
        # min/max of one image: that is 0 for a uniform image, which zeroes
        # SSIM's stabilising constants (0/0 -> NaN on flat regions), and it
        # would make the score depend on argument order.
        return ssim(img1_gray, img2_gray, data_range=self._data_range(img1_gray, img2_gray))

    @staticmethod
    def _to_gray(img):
        """Return a 2-D grayscale version of a gray, single-channel, RGB or RGBA array."""
        if img.ndim == 2:
            return img
        channels = img.shape[2]
        if channels == 1:
            return img[:, :, 0]
        if channels == 4:
            return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    @staticmethod
    def _data_range(img_a, img_b):
        """Return the SSIM ``data_range`` for two grayscale images.

        Integer images use the full range of their dtype (255 for uint8).
        Floating-point images use the combined observed range of both
        images, falling back to 1.0 when both are constant.
        """
        if np.issubdtype(img_a.dtype, np.integer):
            info = np.iinfo(img_a.dtype)
            return float(info.max) - float(info.min)
        low = min(float(img_a.min()), float(img_b.min()))
        high = max(float(img_a.max()), float(img_b.max()))
        return (high - low) or 1.0
# Largest allowed element-wise absolute difference between the ResNet50
# feature vectors of the reference and a comparison image for the
# deep-feature check to pass.
# NOTE(review): magic value with no documented derivation — tune on real data.
_MAX_FEATURE_DIFF = 6.0


def _resolve_comparison_path(file_item):
    """Return the filesystem path held by one entry of the comparison upload.

    Accepts a path string or ``os.PathLike``, an object exposing ``.name``
    (presumably the tempfile wrapper ``gr.Files(type="file")`` yields), or a
    dict with a ``"name"`` key.

    Raises:
        ValueError: if ``file_item`` matches none of these shapes.
    """
    # PathLike must be checked before ``.name``: pathlib.Path.name is only the basename.
    if isinstance(file_item, (str, os.PathLike)):
        return os.fspath(file_item)
    if hasattr(file_item, "name"):
        return file_item.name
    if isinstance(file_item, dict) and "name" in file_item:
        return file_item["name"]
    raise ValueError("Invalid file object structure.")


def process_images(reference_image_array, comparison_files, similarity_threshold):
    """Classify each comparison image as SIMILAR / NOT SIMILAR to the reference.

    An image is SIMILAR only when both checks pass:
      * structure: its grayscale SSIM score against the reference is at
        least ``similarity_threshold``;
      * deep features: the largest absolute difference between its ResNet50
        feature vector and the reference's is below ``_MAX_FEATURE_DIFF``.

    Args:
        reference_image_array: reference image as a NumPy array (as produced
            by ``gr.Image(type="numpy")``), or None if nothing was uploaded.
        comparison_files: iterable of uploaded files; see
            ``_resolve_comparison_path`` for the accepted item shapes.
        similarity_threshold: minimum SSIM score (0.0-1.0) for a match.

    Returns:
        ``(html, images)``: an HTML report with one line per comparison file
        and the list of successfully decoded comparison images (RGB arrays).
        Per-file failures are reported inline and processing continues; any
        other failure yields a single "Critical Error" paragraph and ``[]``.
    """
    from html import escape

    try:
        if reference_image_array is None:
            return "<p style='color:red;'>Please upload a reference image.</p>", []
        if not comparison_files:
            return "<p style='color:red;'>Please upload comparison images.</p>", []

        classifier = ImageCharacterClassifier(similarity_threshold)
        ref_image_pil = Image.fromarray(reference_image_array).convert("RGB")
        ref_features = classifier.extract_features(ref_image_pil)
        # Take the SSIM reference from the RGB PIL copy so it has the same
        # channel order and channel count as the comparison images below.
        ref_image_for_ssim = np.array(ref_image_pil)

        gallery_images = []
        html_parts = ["<h3>Comparison Results:</h3>"]
        for comp_file_item in comparison_files:
            # Reset per item so an error is never attributed to the previous file.
            file_path = None
            try:
                file_path = _resolve_comparison_path(comp_file_item)
                with Image.open(file_path) as opened:
                    comp_pil = opened.convert("RGB")
                comp_array = np.array(comp_pil)

                # Structural check against the user-selected SSIM threshold.
                ssim_score = classifier.calculate_ssim(ref_image_for_ssim, comp_array)
                ssim_match = ssim_score >= similarity_threshold

                # Semantic check on the ResNet50 feature vectors.
                comp_features = classifier.extract_features(comp_pil)
                max_feature_diff = float(np.max(np.abs(ref_features - comp_features)))
                feature_match = max_feature_diff < _MAX_FEATURE_DIFF

                is_similar = feature_match and ssim_match
                label = "SIMILAR" if is_similar else "NOT SIMILAR"
                status_text = f"{label} (SSIM: {ssim_score:.3f}, max feature diff: {max_feature_diff:.2f})"
                status_color = "green" if is_similar else "red"
                # File names are user-controlled, so escape before embedding in HTML.
                html_parts.append(
                    f"<p style='color:{status_color};'>"
                    f"{escape(os.path.basename(file_path))}: {status_text}</p>"
                )
                gallery_images.append(comp_array)
            except Exception as e:  # best-effort: report this file, continue with the rest
                error_name = os.path.basename(file_path) if file_path else "Unknown File"
                html_parts.append(
                    f"<p style='color:red;'>Error processing {escape(error_name)}: {escape(str(e))}</p>"
                )
        return "".join(html_parts), gallery_images
    except Exception as e:
        return f"<p style='color:red;'>Critical Error: {escape(str(e))}</p>", []
# --- SAMPLE DATA DEFINITION ---
# Images used by the "Load & Run Sample Set" buttons. The paths are relative,
# so they resolve against the working directory, and the files must exist
# there for those buttons to work.
# NOTE: the two file names really are spelled differently ("license" vs "licence").
_SAMPLE_LICENSE3 = "sample_data/license3.jpg"
_SAMPLE_LICENCE = "sample_data/licence.jpeg"

SAMPLE_FILES_SET1 = {
    "reference": _SAMPLE_LICENSE3,
    "comparisons": [
        _SAMPLE_LICENSE3,
        _SAMPLE_LICENSE3,
        _SAMPLE_LICENCE,
    ],
}

SAMPLE_FILES_SET2 = {
    "reference": _SAMPLE_LICENCE,
    "comparisons": [
        _SAMPLE_LICENCE,
        _SAMPLE_LICENSE3,
        _SAMPLE_LICENCE,
        _SAMPLE_LICENCE,
    ],
}
# --- GRADIO UI SETUP ---
def create_interface():
    """Build the Gradio Blocks UI for the image similarity classifier.

    Layout: a collapsed guidelines accordion; the input controls (reference
    image, comparison files, SSIM threshold slider, compare button); the
    results area (HTML report and image gallery); and two buttons that load
    and run the bundled sample sets.

    Returns:
        The configured, not yet launched, ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="Image Similarity Classifier") as interface:
        gr.Markdown("# Image Similarity Classifier (ResNet + SSIM)")
        gr.Markdown("Tool to compare a reference image against multiple comparison images based on structural and deep feature similarity.")

        # 1. Guidelines section (collapsed by default)
        with gr.Accordion("User Guidelines and Documentation", open=False):
            gr.Markdown(GUIDELINE_SETUP)
            gr.Markdown("---")
            gr.Markdown(GUIDELINE_INPUT)
            gr.Markdown("---")
            gr.Markdown(GUIDELINE_OUTPUT)
        gr.Markdown("---")

        # 2. Application interface: inputs, then results
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Step 1: Upload a Reference Image ")
                reference_input = gr.Image(label="Reference Image", type="numpy", height=300)
                gr.Markdown("## Step 2: Upload Multiple Images to Compare with the Reference Image ")
                comparison_input = gr.Files(label="Comparison Images", type="file")
                gr.Markdown("## Step 3: Set the Similarity Threshold (Optional) ")
                threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Similarity Threshold (SSIM)")
                gr.Markdown("## Step 4: Click Compare Images ")
                submit_button = gr.Button("Compare Images", variant="primary")

        gr.Markdown("---")
        gr.Markdown("# Results ")
        gr.Markdown("## Comparison Result ")
        output_html = gr.HTML(label="Comparison Results")
        gr.Markdown("## Processed Comparison Images")
        output_gallery = gr.Gallery(label="Processed Comparison Images", columns=3)

        # 3. Sample data
        gr.Markdown("---")
        gr.Markdown("## Sample Data for Testing")
        gr.Markdown("### Click either button below to load and run a sample set ")

        def load_and_run_set(reference_path, comparison_paths, threshold_value=0.5):
            """Load a sample set into the inputs, run the comparison, and return everything.

            Returns a 5-tuple matching the sample buttons' outputs:
            (reference array, comparison paths, threshold, HTML report, gallery).
            If the reference image cannot be opened, the image inputs are
            cleared and the error is shown in the report instead of raising.
            """
            from html import escape

            try:
                # Context manager closes the file handle PIL keeps open.
                with Image.open(reference_path) as ref_img:
                    ref_img_array = np.array(ref_img.convert("RGB"))
            except OSError as err:  # includes FileNotFoundError and undecodable images
                message = f"<p style='color:red;'>Could not load sample reference image: {escape(str(err))}</p>"
                return None, None, threshold_value, message, []

            # gr.Files accepts a list of path strings as its new value.
            report_html, gallery = process_images(ref_img_array, comparison_paths, threshold_value)
            return ref_img_array, comparison_paths, threshold_value, report_html, gallery

        with gr.Row():
            btn_set1 = gr.Button("Load & Run Sample Set 1 (Similar Docs)", size="sm")
            btn_set2 = gr.Button("Load & Run Sample Set 2 (Dissimilar Docs)", size="sm")

        # 4. Event handling
        submit_button.click(
            fn=process_images,
            inputs=[reference_input, comparison_input, threshold_slider],
            outputs=[output_html, output_gallery],
        )
        # Sample buttons populate the inputs as well as the outputs.
        btn_set1.click(
            fn=lambda: load_and_run_set(SAMPLE_FILES_SET1["reference"], SAMPLE_FILES_SET1["comparisons"], 0.6),
            inputs=[],
            outputs=[reference_input, comparison_input, threshold_slider, output_html, output_gallery],
        )
        btn_set2.click(
            fn=lambda: load_and_run_set(SAMPLE_FILES_SET2["reference"], SAMPLE_FILES_SET2["comparisons"], 0.4),
            inputs=[],
            outputs=[reference_input, comparison_input, threshold_slider, output_html, output_gallery],
        )
    return interface
def _launch():
    """Build the UI, enable Gradio's request queue, and serve it on 0.0.0.0:7860."""
    # The sample-set buttons need sample_data/license3.jpg and
    # sample_data/licence.jpeg relative to the working directory.
    app = create_interface()
    app.queue()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


if __name__ == "__main__":
    _launch()