"""Gradio demo: classify an image with a fastai learner and overlay a Grad-CAM heatmap."""

import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from fastai.vision.all import load_learner, PILImage

# Request handlers do not run on the main thread, where interactive
# matplotlib backends are unsafe; force the non-interactive Agg backend.
plt.switch_backend("Agg")


# Ensure custom classes exist before loading the model
class Hook:
    """Forward hook that passes a module's output to ``func``."""

    def __init__(self, module, func):
        self.hook = module.register_forward_hook(
            lambda mod, inp, out: func(out)
        )

    def remove(self):
        """Detach the hook from its module."""
        self.hook.remove()


class HookBwd:
    """Backward hook that passes the gradient w.r.t. a module's output to ``func``."""

    def __init__(self, module, func):
        self.hook = module.register_full_backward_hook(
            lambda mod, grad_input, grad_output: func(grad_output[0])
        )

    def remove(self):
        """Detach the hook from its module."""
        self.hook.remove()


# Load the learner.
# NOTE(security): load_learner unpickles the file, which can execute arbitrary
# code. Only load exports that come from a trusted source.
try:
    learn = load_learner('export.pkl')
except Exception as e:
    # Keep the app importable, but leave an explicit marker so requests fail
    # with a clear message instead of a NameError.
    learn = None
    print(f"Error loading model: {e}")
else:
    print("Model loaded successfully!")


def _grad_cam(img, class_idx):
    """Return a Grad-CAM map (2-D float array in [0, 1]) for ``class_idx``."""
    model = learn.model
    model.eval()  # deterministic dropout / batch-norm behaviour for inference
    target_layer = model[0][-1]  # Adjust based on model architecture

    activations, gradients = [], []
    batch = learn.dls.test_dl([img]).one_batch()[0]
    batch.requires_grad_()

    # Hooks are attached only around this pass so the stored activation and
    # gradient are guaranteed to belong to the same forward/backward run.
    fwd_hook = Hook(target_layer, activations.append)
    bwd_hook = HookBwd(target_layer, gradients.append)
    try:
        with torch.enable_grad():
            output = model(batch)
            output[0, class_idx].backward()
    finally:
        # Always detach: a leaked hook would stay on the shared model and
        # fire on every later request.
        fwd_hook.remove()
        bwd_hook.remove()
        # Parameter gradients are a by-product here; don't let them pile up.
        model.zero_grad()

    act = activations[-1].detach().cpu().squeeze(0)
    grad = gradients[-1].detach().cpu().squeeze(0)

    # Channel weights are the spatially averaged gradients.
    weights = grad.mean(dim=(1, 2), keepdim=True)
    cam = (weights * act).sum(0).clamp(min=0).numpy()

    # Normalize to [0, 1]; a constant map (e.g. all zeros after the clamp)
    # would otherwise divide by zero and produce NaNs.
    low, high = cam.min(), cam.max()
    if high > low:
        return (cam - low) / (high - low)
    return np.zeros_like(cam)


def _render_overlay(img, cam):
    """Draw ``cam`` over ``img`` and return the result as a uint8 RGB array."""
    width, height = img.size
    buf = io.BytesIO()
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        # Stretch the low-resolution map over the image's pixel grid.
        # NOTE(review): this assumes the learner's item transforms keep the
        # whole frame (e.g. a squishing resize); with cropping transforms the
        # heatmap will be slightly misaligned -- confirm against the export.
        ax.imshow(
            cam,
            alpha=0.5,
            cmap='jet',
            extent=(-0.5, width - 0.5, height - 0.5, -0.5),
            interpolation='bilinear',
        )
        ax.axis('off')
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request leaks a figure.
        plt.close(fig)

    buf.seek(0)
    # Decode to an array: Gradio image outputs accept numpy arrays, PIL images
    # or file paths, not raw byte buffers.
    pixels = plt.imread(buf, format='png')  # float values in [0, 1] for PNG
    return (pixels[..., :3] * 255).round().astype(np.uint8)


# Function to predict + generate Class Activation Map (CAM)
def predict_with_cam(img):
    """Classify ``img`` and build a Grad-CAM overlay for the predicted class.

    Args:
        img: The input image as delivered by the Gradio image input
            (a PIL image), or ``None`` when nothing was submitted.

    Returns:
        A ``(scores, overlay)`` tuple: ``scores`` maps each vocabulary label
        to its predicted probability, ``overlay`` is a uint8 RGB array of
        the input image with the heatmap drawn on top.

    Raises:
        gr.Error: If the model failed to load at start-up or no image was
            provided.
    """
    if learn is None:
        raise gr.Error("Model is not loaded; see the server log for details.")
    if img is None:
        raise gr.Error("Please provide an image.")

    img = PILImage.create(img)

    # Run prediction
    pred_class, pred_idx, probs = learn.predict(img)

    cam = _grad_cam(img, int(pred_idx))
    overlay = _render_overlay(img, cam)

    scores = {learn.dls.vocab[i]: float(p) for i, p in enumerate(probs)}
    return scores, overlay


# Create Gradio interface
interface = gr.Interface(
    fn=predict_with_cam,
    inputs=gr.Image(type='pil'),
    outputs=[gr.Label(num_top_classes=3), gr.Image(type='pil')],
    title="Image Classifier with CAM",
)

if __name__ == "__main__":
    interface.launch()