Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| """ | |
| Application for ResNet50 trained on ImageNet-1K. | |
| """ | |
# Standard Library Imports
import warnings

# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference
def load_model(model_path: str):
    """
    Build a ResNet50 and load custom weights from a checkpoint file.

    The checkpoint is read with CPU mapping and must contain a
    ``'model_state_dict'`` entry. Only entries whose names also exist in the
    freshly built ResNet50 are loaded; all other entries are ignored.

    Args:
        model_path: Path to the ``.pth`` checkpoint file.

    Returns:
        The ResNet50 model, switched to evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # Architecture only: the weights come from the checkpoint below.
    # ``weights=None`` is the current spelling of the deprecated
    # ``pretrained=False``.
    model = models.resnet50(weights=None)

    # Load custom weights from a .pth file with CPU mapping
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

    # Collect the expected parameter names once instead of rebuilding the
    # model's state dict for every checkpoint key.
    expected_keys = set(model.state_dict())
    filtered_state_dict = {
        key: value
        for key, value in checkpoint['model_state_dict'].items()
        if key in expected_keys
    }

    # strict=False accepts a partial state dict without complaint, so surface
    # the parameters that were left untouched instead of failing silently.
    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    if load_result.missing_keys:
        warnings.warn(
            f"{len(load_result.missing_keys)} ResNet50 parameter(s) were not "
            f"found in '{model_path}' and keep their initial values",
            stacklevel=2,
        )

    model.eval()
    return model
def load_classes():
    """
    Return the class names used to label predictions.

    The names are read from the ``"categories"`` entry in the metadata of
    torchvision's ``ResNet50_Weights.IMAGENET1K_V2`` weights.
    """
    weights_meta = models.ResNet50_Weights.IMAGENET1K_V2.meta
    return weights_meta["categories"]
def main():
    """
    Build the Gradio interface for the ResNet50 demo and launch it.

    The model and the class names are loaded once and bound to ``inference``
    with ``functools.partial``. The interface has a single "GradCam" tab with
    an image input, three sliders, a button and a set of example images.
    """
    from functools import partial

    # One-time startup work: model weights first, then the class names.
    net = load_model("resnet50_imagenet1k.pth")
    class_names = load_classes()

    # The UI supplies only the image and slider values; the model and the
    # class names are pre-bound here.
    inference_fn = partial(inference, model=net, classes=class_names)

    # Every example uses the slider defaults: transparency 0.5,
    # 3 top predictions, target layer 4.
    example_images = [
        "./assets/examples/dog.jpg",
        "./assets/examples/cat.jpg",
        "./assets/examples/frog.jpg",
        "./assets/examples/bird.jpg",
        "./assets/examples/shark-plane.jpg",
        "./assets/examples/car.jpg",
        "./assets/examples/truck.jpg",
        "./assets/examples/horse.jpg",
        "./assets/examples/plane.jpg",
        "./assets/examples/ship.png",
    ]
    example_rows = [[image_path, 0.5, 3, 4] for image_path in example_images]

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # ImageNet-1K trained on ResNet50v2
            """
        )

        # ------------------------------ GradCam tab ------------------------------
        with gr.Tab("GradCam"):
            gr.Markdown(
                """
                Visualize Class Activations Maps generated by the model's layer for the predicted class.
                This is used to see what the model is actually looking at in the image.
                """
            )

            # First row: the input image next to both outputs.
            with gr.Row():
                image_in = gr.Image(label="Input Image", type="numpy", height=224)
                predictions_out = gr.Label(label="Predictions")
                gradcam_out = gr.Image(label="GradCAM Output", height=224)

            # Second row: the three tuning sliders.
            with gr.Row():
                transparency = gr.Slider(0, 1, value=0.5, label="Activation Map Transparency")
                top_predictions = gr.Slider(1, 10, value=3, step=1, label="Number of Top Predictions")
                target_layer = gr.Slider(1, 6, value=4, step=1, label="Target Layer Number")

            # Component order matches the column order of the example rows.
            ui_inputs = [image_in, transparency, top_predictions, target_layer]
            ui_outputs = [predictions_out, gradcam_out]

            gradcam_button = gr.Button("Generate GradCAM")
            gradcam_button.click(inference_fn, inputs=ui_inputs, outputs=ui_outputs)

            gr.Markdown("## Examples")
            gr.Examples(
                examples=example_rows,
                inputs=ui_inputs,
                fn=inference_fn,
                outputs=ui_outputs,
            )

        # Launch the demo while still inside the Blocks context
        demo.launch(debug=True)


if __name__ == "__main__":
    main()