Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import numpy as np | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
# Class names, indexed by the position of the corresponding model output.
# NOTE(review): this order must match the label encoding used at training
# time — confirm against the training script.
labels = ['Healthy', 'Autistic', 'NDD']

# Load model: ConvNeXt-Tiny architecture (no pretrained download; weights come
# from the local checkpoint below) whose final head layer, classifier[2], is
# replaced by a small MLP.
model = models.convnext_tiny(weights=None)
model.classifier[2] = torch.nn.Sequential(
    torch.nn.Linear(model.classifier[2].in_features, 512),
    torch.nn.GELU(),
    torch.nn.Dropout(p=0.5),
    # One logit per label; derived from `labels` so the two cannot drift apart.
    torch.nn.Linear(512, len(labels)),
)
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host.
model.load_state_dict(
    torch.load("autism_model_weights.pth", map_location=torch.device("cpu"))
)
model.eval()  # inference mode: turns off the Dropout in the head above

# Grad-CAM setup: heatmaps are computed from the last feature stage.
target_layers = [model.features[-1]]
cam = GradCAM(model=model, target_layers=target_layers)
# Image preprocessing: resize to the model's 224x224 input, convert to a
# tensor and normalize with the standard ImageNet mean/std. The pipeline is
# built once at import time instead of on every call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])


def preprocess(img):
    """Convert a PIL image into a normalized single-image batch tensor.

    Args:
        img: PIL image of any mode; it is converted to RGB first.

    Returns:
        Float tensor of shape (1, 3, 224, 224) ready to feed to the model.
    """
    # convert("RGB") guards against RGBA or grayscale uploads, which would
    # otherwise produce 4 or 1 channels and make the 3-channel Normalize fail.
    return _PREPROCESS(img.convert("RGB")).unsqueeze(0)
# Prediction function
def classify(image):
    """Classify an uploaded image and build a Grad-CAM overlay for it.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A tuple ``(preds, overlay)``. ``preds`` maps every entry of ``labels``
        to its softmax probability, which is the dict format ``gr.Label``
        consumes. ``overlay`` is a 224x224 PIL image of the Grad-CAM heatmap
        blended over the resized input.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Normalize to 3 channels once: both the model input and
    # show_cam_on_image (which expects an HxWx3 array) need RGB.
    image = image.convert("RGB")
    input_tensor = preprocess(image)

    with torch.no_grad():
        logits = model(input_tensor)
    probs = F.softmax(logits[0], dim=0)
    # Report every class; gr.Label handles sorting and truncation itself.
    preds = {label: float(p) for label, p in zip(labels, probs)}

    # Grad-CAM needs gradients, so it runs outside no_grad. No targets are
    # passed; presumably pytorch_grad_cam then explains the highest-scoring
    # class — confirm for the installed version.
    grayscale_cam = cam(input_tensor=input_tensor)[0]
    rgb_img = np.asarray(image.resize((224, 224)), dtype=np.float32) / 255.0
    cam_img = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    return preds, Image.fromarray(cam_img)
# Gradio interface
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    # Show every class rather than a hard-coded 3, so the UI tracks `labels`.
    outputs=[gr.Label(num_top_classes=len(labels)), gr.Image(type="pil")],
    title="AutismLens",
    # NOTE(review): this description looks like a placeholder; consider
    # describing what the app does and how its output should be interpreted.
    description="Generate new test",
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch()