"""AutismLens: ConvNeXt-Tiny 3-class image classifier with Grad-CAM overlay, served via Gradio."""

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import models, transforms

# Class labels, in the index order of the final linear layer's outputs.
labels = ['Healthy', 'Autistic', 'NDD']

# Build the ConvNeXt-Tiny backbone (no pretrained weights — we load our own below)
# and swap its final classifier linear layer for the fine-tuned 3-class head.
model = models.convnext_tiny(weights=None)
model.classifier[2] = torch.nn.Sequential(
    torch.nn.Linear(model.classifier[2].in_features, 512),
    torch.nn.GELU(),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(512, 3),
)
# weights_only=True: the checkpoint is a plain state_dict, so refuse to unpickle
# arbitrary objects (safer if the file is ever replaced by untrusted data).
model.load_state_dict(
    torch.load(
        "autism_model_weights.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
model.eval()

# Grad-CAM over the last feature stage of the backbone. With no explicit
# targets passed at call time, GradCAM explains the highest-scoring class.
target_layers = [model.features[-1]]
cam = GradCAM(model=model, target_layers=target_layers)

# Built once at module level — the pipeline is identical for every request,
# so there is no reason to reconstruct it per call. ImageNet mean/std match
# the torchvision pretraining convention used at fine-tune time.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def preprocess(img):
    """Convert a PIL image to a normalized (1, 3, 224, 224) float tensor.

    Forces RGB first: Gradio may hand us grayscale (L) or RGBA uploads,
    which would otherwise break the 3-channel Normalize.
    """
    return _TRANSFORM(img.convert("RGB")).unsqueeze(0)


def classify(image):
    """Classify a face image and return (label probabilities, Grad-CAM overlay).

    Args:
        image: PIL.Image from the Gradio input component (any mode).

    Returns:
        A dict mapping each of the 3 labels to its softmax probability,
        and a PIL.Image showing the Grad-CAM heatmap over the input.
    """
    # Normalize mode once so both the tensor path and the overlay path
    # see a 3-channel RGB image.
    image = image.convert("RGB")
    input_tensor = preprocess(image)

    # Inference pass needs no gradients; Grad-CAM below re-enables them itself.
    with torch.no_grad():
        outputs = model(input_tensor)
    probs = F.softmax(outputs[0], dim=0)
    top_probs, top_idxs = torch.topk(probs, 3)
    preds = {
        labels[idx.item()]: prob.item()
        for prob, idx in zip(top_probs, top_idxs)
    }

    # Grad-CAM heatmap for the top predicted class, blended onto the
    # resized input (show_cam_on_image expects floats in [0, 1]).
    grayscale_cam = cam(input_tensor=input_tensor)[0]
    rgb_img = np.array(image.resize((224, 224))).astype(np.float32) / 255.0
    cam_img = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    return preds, Image.fromarray(cam_img)


# Gradio interface
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil")],
    title="AutismLens",
    description="Generate new test",
)

# Launch only when run as a script, so importing this module (e.g. for
# testing or embedding the app elsewhere) does not start a server.
if __name__ == "__main__":
    demo.launch()