# Spaces:
# Build error
# Build error
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from fastai.vision.all import load_learner, PILImage
# Ensure custom classes exist before loading the model.
# NOTE(review): presumably export.pkl references these class names, so they
# must be defined in this module before load_learner runs — confirm against
# the notebook that exported the learner.
class Hook:
    """Capture a module's forward output through a forward hook.

    Args:
        module: the ``torch.nn.Module`` to observe.
        func: callable invoked with the module's output after every forward
            pass. Its return value is ignored.

    PyTorch replaces a module's output with whatever non-None value a forward
    hook returns, so the callback's result is deliberately discarded; passing
    a ``func`` that happens to return something must not alter the model.

    Call ``remove()`` to detach the hook, or use the instance as a context
    manager so the hook is detached even if the guarded code raises.
    """

    def __init__(self, module, func):
        def _hook(mod, inp, out):
            func(out)
            # Implicitly return None so the module's output is left untouched.

        self.hook = module.register_forward_hook(_hook)

    def remove(self):
        """Detach the hook from its module."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False
class HookBwd:
    """Capture the gradient flowing into a module's output during backward.

    Args:
        module: the ``torch.nn.Module`` to observe.
        func: callable invoked with ``grad_output[0]`` (the gradient w.r.t.
            the module's first output) each time backward reaches the module.
            Its return value is ignored.

    A full backward hook that returns a non-None value replaces the module's
    ``grad_input``, which would corrupt or break gradient propagation, so the
    callback's result is deliberately discarded.

    Call ``remove()`` to detach the hook, or use the instance as a context
    manager so the hook is detached even if the guarded code raises.
    """

    def __init__(self, module, func):
        def _hook(mod, grad_input, grad_output):
            func(grad_output[0])
            # Implicitly return None so grad_input is left untouched.

        self.hook = module.register_full_backward_hook(_hook)

    def remove(self):
        """Detach the hook from its module."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False
# Load the learner. If loading fails, keep the app importable but bind
# ``learn`` to None so request handlers can detect the missing model instead
# of failing with a NameError on an unbound global.
# NOTE(review): load_learner unpickles export.pkl — only ever load a trusted,
# locally produced file here, never a user-supplied one.
try:
    learn = load_learner('export.pkl')
except Exception as e:
    learn = None
    print(f"Error loading model: {e}")
else:
    print("Model loaded successfully!")
# Function to predict + generate a Grad-CAM style Class Activation Map (CAM).
def _compute_grad_cam(model, target_layer, x, class_idx):
    """Return a Grad-CAM heat map of ``class_idx`` at ``target_layer``.

    Runs one forward/backward pass of ``model`` on ``x``, capturing the
    layer's output activations and the gradient of the class score with
    respect to that output. Channel weights are the spatially averaged
    gradients; the map is the ReLU of the weighted channel sum, min-max
    normalised to [0, 1] (all zeros if the map is flat).

    Args:
        model: module whose output is indexed as ``[0, class_idx]``.
        target_layer: sub-module of ``model`` whose output is visualised.
        x: input batch. NOTE(review): assumed to hold exactly one image
            (the ``squeeze(0)`` calls rely on it).
        class_idx: integer index of the class to explain.

    Returns:
        2-D float numpy array (layer height x layer width).
    """
    activations, gradients = [], []
    fwd_hook = Hook(target_layer, activations.append)
    bwd_hook = HookBwd(target_layer, gradients.append)
    try:
        # Make gradient tracking explicit regardless of the caller's grad mode.
        with torch.enable_grad():
            # Ensure a graph is built even if the model's parameters are frozen.
            x.requires_grad_()
            scores = model(x)
            scores[0, class_idx].backward()
    finally:
        # Always detach the hooks so they never accumulate on the shared model.
        fwd_hook.remove()
        bwd_hook.remove()
        # Drop the parameter gradients this pass produced so they do not pile
        # up across requests.
        model.zero_grad(set_to_none=True)

    act = activations[0].detach().cpu().squeeze(0)   # (C, H, W)
    grad = gradients[0].detach().cpu().squeeze(0)    # (C, H, W)
    weights = grad.mean(dim=(1, 2), keepdim=True)    # (C, 1, 1)
    cam = (weights * act).sum(0).clamp(min=0).numpy()

    lo, hi = cam.min(), cam.max()
    if hi > lo:
        return (cam - lo) / (hi - lo)
    # A flat map (e.g. all zeros after the ReLU) would otherwise divide by 0.
    return np.zeros_like(cam)


def _render_cam_overlay(img, cam, alpha=0.5, cmap='jet'):
    """Draw ``cam`` over ``img`` and return the composite as a PILImage.

    Both layers share one pixel extent so the low-resolution heat map is
    stretched over the whole image; bilinear interpolation smooths it.
    The figure is always closed to avoid leaking one figure per request.
    """
    width, height = img.size
    extent = (0, width, height, 0)
    fig, ax = plt.subplots()
    try:
        ax.imshow(img, extent=extent)
        ax.imshow(cam, alpha=alpha, cmap=cmap, extent=extent,
                  interpolation='bilinear')
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    # gr.Image outputs accept PIL images but not file-like buffers.
    return PILImage.create(buf.getvalue())


def predict_with_cam(img):
    """Classify ``img`` and visualise the prediction with a Grad-CAM overlay.

    Args:
        img: image accepted by ``PILImage.create``; Gradio passes a PIL image
            (or None when nothing was uploaded).

    Returns:
        ``(label_probs, overlay)``: a dict mapping each vocab entry to its
        predicted probability, and a PILImage of the input with the CAM
        heat map overlaid.

    Raises:
        gr.Error: if no image was supplied or the model failed to load.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    if learn is None:
        raise gr.Error("Model is not loaded; check the server logs.")

    img = PILImage.create(img)
    # Predict before attaching any hooks so the captured activations come
    # only from the gradient-enabled pass inside _compute_grad_cam.
    _, pred_idx, probs = learn.predict(img)

    model = learn.model
    # NOTE(review): model[0][-1] assumes a fastai vision_learner layout (body
    # at index 0, its last stage at -1); adjust for other architectures.
    target_layer = model[0][-1]
    x = learn.dls.test_dl([img]).one_batch()[0]
    cam = _compute_grad_cam(model, target_layer, x, int(pred_idx))

    # NOTE(review): the CAM is computed on the dls-transformed input but drawn
    # over the original image; if the item transforms crop, the overlay is only
    # approximately aligned.
    overlay = _render_cam_overlay(img, cam)
    label_probs = {label: float(p) for label, p in zip(learn.dls.vocab, probs)}
    return label_probs, overlay
# Build the Gradio UI: one PIL image in; the top-3 class probabilities and
# the CAM overlay out. Components are created in the same order as before.
image_input = gr.Image(type='pil')
label_output = gr.Label(num_top_classes=3)
cam_output = gr.Image(type='pil')

interface = gr.Interface(
    fn=predict_with_cam,
    inputs=image_input,
    outputs=[label_output, cam_output],
    title="Image Classifier with CAM",
)
interface.launch()