Spaces:
Sleeping
Sleeping
| import os | |
| import ssl | |
| import cv2 | |
| import torch | |
| import certifi | |
| import numpy as np | |
| import gradio as gr | |
| import torch.nn as nn | |
| from torchvision import models | |
| import torch.nn.functional as F | |
| import matplotlib.pyplot as plt | |
| from PIL import Image, ImageEnhance | |
| import torchvision.transforms as transforms | |
# Make the default HTTPS context trust certifi's CA bundle. This was presumably
# added for the pretrained-weight download; it is kept because other HTTPS
# calls in the app may rely on it.
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())

# Run on the GPU when one is available; inputs are sent here by transform_image.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of output classes; must match the head stored in model.pth
# (load_state_dict below is strict and raises on a shape mismatch).
num_classes = 6

# Build the ResNet-152 architecture without downloading ImageNet weights:
# the strict load_state_dict below overwrites every parameter and buffer, so
# the pretrained values would be discarded anyway.
model = models.resnet152()
for param in model.parameters():
    # No effect under torch.no_grad() at inference; presumably mirrors the
    # frozen-backbone training setup.
    param.requires_grad = False

# Replace the classifier with a 2-layer head (in_features -> 512 -> num_classes).
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
)

# Load the trained weights on the CPU, then move the model to the same device
# the input tensors are placed on (otherwise CUDA hosts hit a device mismatch).
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model.to(device)
model.eval()

# Class names indexed by model output position; the order must match the
# label order used when training model.pth (assumed — not verifiable here).
class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']
| # Image transformation function | |
def transform_image(image):
    """Preprocess a PIL image into a model-ready batch tensor.

    The image is converted to RGB, resized to 32x32, scaled to [0, 1] by
    ``ToTensor`` and normalized with mean/std 0.5 per channel (mapping to
    [-1, 1]). The 32x32 size and normalization presumably match how
    model.pth was trained — confirm against the training script.

    Args:
        image: A PIL image in any mode.

    Returns:
        A float tensor of shape (1, 3, 32, 32) on ``device``.
    """
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Normalize is configured for exactly three channels; RGBA, grayscale or
    # palette images would otherwise fail with a broadcast error.
    rgb_image = image.convert("RGB")
    return transform(rgb_image).unsqueeze(0).to(device)
| # Apply feature filters | |
def apply_filters(image, brightness, contrast, hue):
    """Adjust brightness, contrast and hue of a PIL image.

    Args:
        image: Input PIL image (any mode; converted to RGB first).
        brightness: ``ImageEnhance.Brightness`` factor (1.0 = unchanged).
        contrast: ``ImageEnhance.Contrast`` factor (1.0 = unchanged).
        hue: Hue rotation as a fraction of a full turn (e.g. 0.5 = 180 deg).

    Returns:
        A new RGB PIL image.
    """
    image = image.convert("RGB")  # Ensure RGB mode
    image = ImageEnhance.Brightness(image).enhance(brightness)
    image = ImageEnhance.Contrast(image).enhance(contrast)

    # OpenCV stores 8-bit hue in [0, 180), so a full turn is 180 units.
    # Round instead of truncating so positive and negative shifts are symmetric.
    hue_shift = int(round(hue * 180))
    if hue_shift % 180 == 0:
        # The 8-bit RGB -> HSV -> RGB round trip is lossy; skip it entirely
        # when the rotation would be a no-op (e.g. the default hue of 0).
        return image

    hsv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2HSV)
    # Work in int16 so the addition cannot overflow uint8; % 180 wraps
    # negative results back into [0, 180).
    hsv[..., 0] = ((hsv[..., 0].astype(np.int16) + hue_shift) % 180).astype(np.uint8)
    return Image.fromarray(cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB))
| # Superimposition function | |
def superimpose_images(base_image, overlay_image, alpha):
    """Alpha-blend ``overlay_image`` onto ``base_image``.

    The result is ``(1 - alpha) * base + alpha * overlay``, so ``alpha=0``
    keeps only the base and ``alpha=1`` keeps only the overlay.

    Args:
        base_image: PIL image that defines the output size and mode.
        overlay_image: PIL image to blend in, or None to skip blending.
        alpha: Overlay weight.

    Returns:
        The blended PIL image, or ``base_image`` itself when there is no overlay.
    """
    if overlay_image is None:
        return base_image  # No overlay, return base image as is

    # Match the base image's mode and size so both arrays have the same shape;
    # e.g. an RGBA or grayscale overlay on an RGB base would otherwise fail to
    # broadcast in the blend below.
    overlay_image = overlay_image.convert(base_image.mode).resize(base_image.size)

    base_array = np.asarray(base_image, dtype=float)
    overlay_array = np.asarray(overlay_image, dtype=float)

    blended_array = (1 - alpha) * base_array + alpha * overlay_array
    blended_array = np.clip(blended_array, 0, 255).astype(np.uint8)
    return Image.fromarray(blended_array)
| # Prediction function | |
def predict(image, brightness, contrast, hue, overlay_image, alpha):
    """Filter, blend and classify an image, returning it with a probability chart.

    Args:
        image: Base PIL image from the upload widget, or None when cleared.
        brightness: Brightness factor forwarded to ``apply_filters``.
        contrast: Contrast factor forwarded to ``apply_filters``.
        hue: Hue shift forwarded to ``apply_filters``.
        overlay_image: Optional PIL image blended over the filtered image.
        alpha: Overlay weight forwarded to ``superimpose_images``.

    Returns:
        ``(final_image, figure)`` matching the two wired Gradio outputs, or
        ``(None, None)`` when no base image is provided.
    """
    if image is None:
        # Exactly two outputs are wired to this function, so return two values.
        return None, None

    processed_image = apply_filters(image, brightness, contrast, hue)
    final_image = superimpose_images(processed_image, overlay_image, alpha)

    image_tensor = transform_image(final_image)
    with torch.no_grad():
        logits = model(image_tensor)
    probabilities = F.softmax(logits, dim=1).cpu().numpy()[0]

    with plt.xkcd():
        fig, ax = plt.subplots(figsize=(5, 3))
        ax.bar(class_labels, probabilities, color='skyblue')
        ax.set_ylabel("Probability")
        ax.set_title("Class Probabilities")
        ax.set_ylim([0, 1])
        for i, v in enumerate(probabilities):
            ax.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10)

    # This runs on every slider change; unregister the figure from pyplot so
    # figures do not accumulate. Closing only removes it from pyplot's
    # registry — the Figure object can still be saved/rendered by the caller.
    plt.close(fig)
    return final_image, fig
| # Gradio Interface | |
# Build the UI: controls on the left, results on the right. Every control
# re-runs the full predict pipeline when it changes.
with gr.Blocks() as interface:
    gr.Markdown("<h2 style='text-align: center;'>Image Classifier with Superimposition & Adjustable Filters</h2>")

    with gr.Row():
        with gr.Column():
            # Inputs: the two images followed by the adjustment sliders.
            image_input = gr.Image(type="pil", label="Upload Base Image")
            overlay_input = gr.Image(type="pil", label="Upload Overlay Image (Optional)")
            brightness = gr.Slider(0.5, 2.0, value=1.0, label="Brightness")
            contrast = gr.Slider(0.5, 2.0, value=1.0, label="Contrast")
            hue = gr.Slider(-0.5, 0.5, value=0.0, label="Hue")
            alpha = gr.Slider(0.0, 1.0, value=0.5, label="Overlay Weight (Alpha)")

        with gr.Column():
            # Outputs: the blended image and the class-probability chart.
            processed_image = gr.Image(label="Final Processed Image")
            bar_chart = gr.Plot(label="Class Probabilities")

    # Argument order must match predict's signature.
    inputs = [image_input, brightness, contrast, hue, overlay_input, alpha]
    outputs = [processed_image, bar_chart]

    # Real-time updates: register the same handler on each control, in the
    # same order as the controls were originally wired.
    for trigger in (image_input, overlay_input, brightness, contrast, hue, alpha):
        trigger.change(predict, inputs=inputs, outputs=outputs)

interface.launch()