Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| Application for ResNet50 trained on ImageNet-1K. | |
| """ | |
| # Third Party Imports | |
| import gradio as gr | |
| import torch | |
| from torchvision import models | |
| # Local Imports | |
| from inference import inference | |
def load_model(model_path: str):
    """
    Build a ResNet50 and load weights from a local checkpoint.

    Args:
        model_path: Path to a file produced by ``torch.save``. Either a dict
            holding the weights under ``'model_state_dict'`` or a bare state
            dict is accepted.

    Returns:
        The ResNet50 in eval mode, on CUDA when available, otherwise on CPU.

    Raises:
        Exception: any error from ``load_state_dict`` (e.g. missing or
            unexpected keys) is printed and then re-raised unchanged.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize a fresh model without pretrained weights
    model = models.resnet50(weights=None).to(device)

    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    # Training checkpoints nest the weights; a plain state dict is used as-is.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        raw_state_dict = checkpoint['model_state_dict']
    else:
        raw_state_dict = checkpoint
    print("\nOriginal state dict keys:", list(raw_state_dict.keys())[:5])

    # Drop a leading 'model.' from each key (presumably added by a wrapper
    # module during training). Only the prefix is removed: str.replace would
    # also rewrite any 'model.' that appears later inside a key.
    prefix = 'model.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in raw_state_dict.items()
    }
    print("Modified state dict keys:", list(new_state_dict.keys())[:5])
    print("Model state dict keys:", list(model.state_dict().keys())[:5])

    try:
        model.load_state_dict(new_state_dict)
    except Exception as e:
        print(f"Error loading state dict: {str(e)}")
        raise
    print("Successfully loaded model weights")

    model.eval()
    return model
def load_classes():
    """
    Return the ImageNet-1K category names.

    The names are read from the ``categories`` entry of the metadata attached
    to torchvision's ``ResNet50_Weights.IMAGENET1K_V1``; the count is logged.
    """
    categories = models.ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
    print(f"Loaded {len(categories)} classes")
    return categories
def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Gradio callback: run ``inference`` and turn failures into label output.

    Uses the module-level ``model`` and ``classes`` that ``main`` assigns.

    Returns:
        Whatever ``inference`` returns on success. Otherwise a
        ``({message: 1.0}, None)`` pair, so the Label component shows the
        problem and the GradCAM image output is left empty.
    """
    generic_failure = ({"Error": 1.0}, None)
    try:
        # Nothing uploaded yet.
        if image is None:
            return generic_failure
        outcome = inference(
            image, alpha, top_k, target_layer, model=model, classes=classes
        )
        return generic_failure if outcome is None else outcome
    except RuntimeError as err:
        message = str(err)
        print(f"Error in inference: {message}")
        # CUDA OOM surfaces as a RuntimeError; give the user a friendlier hint.
        label = (
            "GPU Memory Error - Please try again"
            if "out of memory" in message.lower()
            else "Runtime Error: " + message
        )
        return {label: 1.0}, None
    except Exception as err:
        message = str(err)
        print(f"Error in inference: {message}")
        return {"Error: " + message: 1.0}, None
def _build_demo():
    """Assemble (but do not launch) the Gradio Blocks UI for predictions and GradCAM."""
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ResNet50 trained on ImageNet-1K
            A large-scale image classification dataset with 1.2 million training images across 1,000 object categories.
            """
        )
        with gr.Tab("Predictions & GradCAM"):
            gr.Markdown(
                """
                View model predictions and visualize where the model is looking using GradCAM.
                ## Steps to use:
                1. Upload an image or select one from the examples below
                2. Adjust the sliders (optional):
                   - Activation Map Transparency: Controls the blend between original image and activation map
                   - Number of Top Predictions: How many top class predictions to show
                   - Target Layer Number: Which network layer to visualize (deeper layers show higher-level features)
                3. Click "Generate GradCAM" to run the model
                4. View the results:
                   - Left: Original uploaded image
                   - Right: Model predictions and GradCAM visualization showing where the model focused
                """
            )
            # Input image on the left, predictions and GradCAM on the right.
            with gr.Row():
                img_input = gr.Image(
                    label="Input Image",
                    type="numpy",
                    height=224,
                    width=224,
                )
                with gr.Column():
                    label_output = gr.Label(label="Predictions")
                    gradcam_output = gr.Image(
                        label="GradCAM Output",
                        height=224,
                        width=224,
                    )
            with gr.Row():
                alpha_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.5,
                    step=0.1,
                    label="Activation Map Transparency",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Number of Top Predictions",
                )
                target_layer_slider = gr.Slider(
                    minimum=1,
                    maximum=6,
                    value=4,
                    step=1,
                    label="Target Layer Number",
                )
            gradcam_button = gr.Button("Generate GradCAM")

            inputs = [img_input, alpha_slider, top_k_slider, target_layer_slider]
            outputs = [label_output, gradcam_output]
            gradcam_button.click(fn=inference_wrapper, inputs=inputs, outputs=outputs)

            # Every example uses the slider defaults: alpha 0.5, top-3, layer 4.
            example_files = (
                "cat.jpg", "frog.jpg", "bird.jpg", "car.jpg",
                "truck.jpg", "horse.jpg", "plane.jpg", "ship.png",
            )
            examples = [
                [f"assets/examples/{name}", 0.5, 3, 4] for name in example_files
            ]
            gr.Examples(
                examples=examples,
                inputs=inputs,
                outputs=outputs,
                fn=inference_wrapper,
                cache_examples=False,  # Disable caching to prevent memory issues
                # Without caching, Gradio only fills the inputs on click unless
                # run_on_click is set; the label below promises a run.
                run_on_click=True,
                label="Click on any example to run GradCAM",
            )
    return demo


def main():
    """
    Load the model and class names, then build and launch the Gradio app.

    ``model`` and ``classes`` are stored as module globals because
    ``inference_wrapper`` reads them. On a startup failure the error is
    printed, the CUDA cache is released, and the exception is re-raised so the
    failure is visible (traceback and non-zero exit) instead of silent.
    """
    global model, classes
    try:
        print(f"Gradio version: {gr.__version__}")
        # Load the model at startup
        model = load_model("resnet50_imagenet1k.pth")
        classes = load_classes()

        demo = _build_demo()

        # max_size caps how many requests may wait in the queue at once;
        # further submissions are rejected as "queue full". It does not set
        # how many jobs run concurrently (that is a separate setting).
        demo.queue(max_size=1)
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
        )
    except Exception as e:
        print(f"Error during startup: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise


if __name__ == "__main__":
    main()